
Commit 808c83b

manuelcandales authored and facebook-github-bot committed
Calibrator: devices
Reviewed By: SS-JIA
Differential Revision: D54544308
fbshipit-source-id: 8b7ef524786d566b869ff21d2b14e64690400514
1 parent 8677dab commit 808c83b

File tree

3 files changed: +95 -13 lines changed

calibrator/runner.py

Lines changed: 89 additions & 8 deletions
@@ -9,6 +9,7 @@
 import sys
 from typing import Any, List, OrderedDict, Tuple
 
+import torch
 from executorch.exir.dialects.edge.op.api import get_callable, to_variant
 from inputgen.argtuple.engine import MetaArgTupleEngine
 from inputgen.argtuple.gen import ArgumentTupleGenerator
@@ -26,13 +27,25 @@ def smt(meta_tuple):
 
 
 class SpecRunner:
-    def __init__(self, spec: Spec, *, valid: bool = True, out: bool = False):
+    def __init__(
+        self,
+        spec: Spec,
+        *,
+        valid: bool = True,
+        out: bool = False,
+        devices: Tuple[str] = ("cpu",),
+    ):
         self.spec = spec
         self.generator = ArgumentTupleGenerator(self.spec)
         self.valid = valid
         self.out = out
         self.op_name = spec.op
         self.op = self.get_callable_op()
+        self.results = {}
+        self.devices = devices
+        self.results = {}
+        for device in self.devices:
+            self.results[device] = {}
 
     def get_callable_op(self):
         name = self.spec.op
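
Two notes on the hunk above: self.results = {} is assigned twice, so the first assignment is redundant, and the per-device dicts are later keyed by the string form of a meta tuple. A sketch of the resulting bookkeeping, with device names and keys invented purely for illustration:

    # Illustrative shape of SpecRunner.results after running on two devices.
    # The device names and meta-tuple keys here are made up.
    results = {
        "cpu": {"(Tensor f32 [2, 3], Tensor f32 [2, 3])": True},
        "cuda": {"(Tensor f32 [2, 3], Tensor f32 [2, 3])": False},
    }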
@@ -42,26 +55,93 @@ def get_callable_op(self):
         op: OpOverload = to_variant(op, SchemaKind.out)
         return op
 
-    def run(self):
+    def report_device(self, device):
+        print(f"Device: {device}\n")
         failures = []
-        engine = MetaArgTupleEngine(self.spec, out=self.out)
-        for meta_tuple in engine.gen(valid=self.valid):
-            success, _, _, _, _ = self.run_meta_tuple(meta_tuple)
+        for meta_tuple in self.results[device]:
+            success = self.results[device][meta_tuple]
             if not success:
                 failures.append(meta_tuple)
         if len(failures) > 0:
             print("FAILURES\n")
             for meta_tuple in failures:
-                print(f"\t{smt(meta_tuple)}\n")
+                print(f"\t{meta_tuple}\n")
         else:
             print("SUCCESS\n")
 
+    def report_inconsistencies(self):
+        print(f"Devices: {' '.join(self.devices)}\n")
+        meta_tuples = self.results[self.devices[0]].keys()
+        inconsistencies = set()
+        for meta_tuple in meta_tuples:
+            res = self.results[self.devices[0]][meta_tuple]
+            for device in self.devices[1:]:
+                res ^= self.results[device][meta_tuple]
+            if not res:
+                inconsistencies.append(meta_tuple)
+        if len(inconsistencies) > 0:
+            print("INCONSISTENCIES\n")
+            for meta_tuple in inconsistencies:
+                res = [self.results[d][meta_tuple] for d in self.devices]
+                res_string = " ".join(["x" if r else "o" for r in res])
+                print(f"\t{res_string} {meta_tuple}\n")
+
+    def run(self):
+        engine = MetaArgTupleEngine(self.spec, out=self.out)
+        for meta_tuple in engine.gen(valid=self.valid):
+            self.run_meta_tuple(meta_tuple)
+        if len(self.devices) > 1:
+            self.report_inconsistencies()
+        for device in self.devices:
+            self.report_device(device)
+
+    def move_to_device(
+        self,
+        device: str,
+        cpu_posargs: List[Any],
+        cpu_inkwargs: OrderedDict[str, Any],
+        cpu_outargs: OrderedDict[str, Any],
+    ):
+        if device == "cpu":
+            return cpu_posargs, cpu_inkwargs, cpu_outargs
+        posargs = []
+        inkwargs = OrderedDict()
+        outargs = OrderedDict()
+        for arg in cpu_posargs:
+            new = arg
+            if isinstance(arg, torch.Tensor):
+                new = arg.to(device=device)
+            posargs.append(new)
+        for k, v in cpu_inkwargs.items():
+            new = v
+            if isinstance(v, torch.Tensor):
+                new = v.to(device=device)
+            inkwargs[k] = new
+        for k, v in cpu_outargs.items():
+            new = v
+            if isinstance(v, torch.Tensor):
+                new = v.to(device=device)
+            outargs[k] = new
+        return posargs, inkwargs, outargs
+
     def run_meta_tuple(
         self, meta_tuple: Tuple[MetaArg]
     ) -> Tuple[bool, Any, List[Any], OrderedDict[str, Any], OrderedDict[str, Any]]:
         print(f"Running op: {self.op_name}, meta_tuple: {[str(x) for x in meta_tuple]}")
         posargs, inkwargs, outargs = self.generator.gen_tuple(meta_tuple, out=self.out)
-        return self.run_values(meta_tuple, posargs, inkwargs, outargs)
+        for device in self.devices:
+            posargs, inkwargs, outargs = self.move_to_device(
+                device, posargs, inkwargs, outargs
+            )
+            success, res, posargs, inkwargs, outargs = self.run_values(
+                meta_tuple, posargs, inkwargs, outargs
+            )
+            mt = smt(meta_tuple)
+            if mt in self.results[device]:
+                logging.warning(f"Repeated meta_tuple {mt}")
+                self.results[device][mt] &= success
+            else:
+                self.results[device][mt] = success
 
     def run_values(
         self,
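
Two observations on the hunk above. First, report_inconsistencies builds inconsistencies as a set() but then calls .append() on it, and its "if not res:" test records a meta tuple when the XOR of the per-device results is False, which for two devices means the devices agreed; the intent is presumably the opposite. Second, run_meta_tuple calls logging.warning, but this diff only adds import torch, so logging must already be imported earlier in runner.py. A minimal corrected sketch of the inconsistency report, keeping the same results layout (an illustration, not part of the commit):

    def report_inconsistencies(self):
        print(f"Devices: {' '.join(self.devices)}\n")
        meta_tuples = self.results[self.devices[0]].keys()
        inconsistencies = []
        for meta_tuple in meta_tuples:
            # Flag the tuple only when the devices disagree on success/failure.
            outcomes = {self.results[d][meta_tuple] for d in self.devices}
            if len(outcomes) > 1:
                inconsistencies.append(meta_tuple)
        if inconsistencies:
            print("INCONSISTENCIES\n")
            for meta_tuple in inconsistencies:
                res = [self.results[d][meta_tuple] for d in self.devices]
                res_string = " ".join("x" if r else "o" for r in res)
                print(f"\t{res_string} {meta_tuple}\n")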
@@ -96,13 +176,14 @@ def main():
         "--invalid", action="store_true", help="generate invalid inputs"
     )
     parser.add_argument("--out", action="store_true", help="run out variants")
+    parser.add_argument("--devices", nargs="*", default=("cpu",), help="run on devices")
     args = parser.parse_args()
 
     if args.op not in SpecDictDB:
         raise RuntimeError(f"Op {args.op} not found in SpecDB")
 
     spec = SpecDictDB[args.op]
-    SpecRunner(spec, valid=not args.invalid, out=args.out).run()
+    SpecRunner(spec, valid=not args.invalid, out=args.out, devices=args.devices).run()
 
 
 if __name__ == "__main__":
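
With the new flag, a single run can cover several devices, e.g. --devices cpu cuda on a machine with CUDA available. The runner can also be driven directly from Python; a hedged sketch, assuming the modules are importable as calibrator.runner and specdb.db and that the op key below exists in the database:

    from calibrator.runner import SpecRunner
    from specdb.db import SpecDictDB

    # Hypothetical op key; any op present in SpecDictDB works the same way.
    spec = SpecDictDB["add.Tensor"]

    # Generate valid inputs for the functional variant and run them on CPU and CUDA.
    SpecRunner(spec, valid=True, out=False, devices=("cpu", "cuda")).run()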

inputgen/argtuple/gen.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ def __init__(self, spec: Spec):
 
     def gen_tuple(
         self, meta_tuple: Tuple[MetaArg], *, out: bool = False
-    ) -> Tuple[List[Any], OrderedDict[str, Any]]:
+    ) -> Tuple[List[Any], OrderedDict[str, Any], OrderedDict[str, Any]]:
         posargs = []
         inkwargs = OrderedDict()
         outargs = OrderedDict()
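
The widened annotation simply documents the third collection (out-variant kwargs) that gen_tuple already returns. A hedged sketch of the consuming side in the runner (run_values is not shown in this diff, so the exact call shape is an assumption):

    posargs, inkwargs, outargs = self.generator.gen_tuple(meta_tuple, out=self.out)
    # Positional inputs, then keyword inputs, then out-variant keyword outputs.
    result = self.op(*posargs, **inkwargs, **outargs)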

specdb/db.py

Lines changed: 5 additions & 4 deletions
@@ -997,9 +997,10 @@
             name="bias",
             deps=[1, 6, 8],
             constraints=[
-                cp.Dtype.In(
-                    lambda deps: [deps[0].dtype] if deps[1] else dt._floating
-                ),  # TODO(mcandales): Calibrate.
+                cp.Dtype.Eq(lambda deps: deps[0].dtype),
+                # cp.Dtype.In(
+                #     lambda deps: [deps[0].dtype] if deps[1] else dt._floating
+                # ),  # TODO(mcandales): Calibrate.
                 cp.Rank.Eq(lambda deps: 1),
                 cp.Size.Eq(
                     lambda deps, r, d: fn.conv_bias_size_eq(
@@ -1039,7 +1040,7 @@
             ArgType.Bool,
             name="transposed",
             # TODO(mcandales): Executorch specific constraint
-            # constraints=[cp.Value.In(lambda deps: [False]]),
+            # constraints=[cp.Value.Eq(lambda deps: False)],
         ),
         InPosArg(  # output_padding
             ArgType.LengthList,
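
A rough paraphrase of the bias-dtype change above, with the caveat that the cp constraint semantics are inferred from their names rather than from the inputgen sources: the old, now commented-out rule allowed any floating dtype for the bias when deps[1] was falsy, while the new cp.Dtype.Eq rule pins the bias dtype to deps[0].dtype.

    # Hypothetical paraphrase of the two bias-dtype rules (not the inputgen API).
    def old_bias_dtype_ok(deps, bias_dtype, floating_dtypes):
        allowed = [deps[0].dtype] if deps[1] else floating_dtypes
        return bias_dtype in allowed

    def new_bias_dtype_ok(deps, bias_dtype):
        return bias_dtype == deps[0].dtype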
