
Commit f753ab5

Update calibrator SpecRunner
ghstack-source-id: e2c95b9 Pull-Request: #35
1 parent 868ca87 commit f753ab5


calibrator/runner.py

Lines changed: 44 additions & 25 deletions
@@ -7,17 +7,17 @@
 import argparse
 import logging
 import sys
-from typing import Any, List, OrderedDict, Tuple
+from typing import Any, List, Optional, OrderedDict, Tuple

 import torch
-from executorch.exir.dialects.edge.op.api import get_callable, to_variant
 from facto.inputgen.argtuple.engine import MetaArgTupleEngine
 from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
 from facto.inputgen.argument.engine import MetaArg
 from facto.inputgen.specs.model import Spec
+from facto.inputgen.utils.config import TensorConfig
 from facto.specdb.db import SpecDictDB
+from facto.utils.ops import get_op_overload
 from torch._ops import OpOverload
-from torchgen.model import SchemaKind

 logging.basicConfig(stream=sys.stderr, level=logging.WARNING)

@@ -29,32 +29,27 @@ def smt(meta_tuple):
 class SpecRunner:
     def __init__(
         self,
+        op: OpOverload,
         spec: Spec,
         *,
         valid: bool = True,
         out: bool = False,
         devices: Tuple[str] = ("cpu",),
+        config: Optional[TensorConfig] = None,
     ):
         self.spec = spec
-        self.generator = ArgumentTupleGenerator(self.spec)
+        self.config = config
+        self.generator = ArgumentTupleGenerator(self.spec, config=config)
         self.valid = valid
         self.out = out
-        self.op_name = spec.op
-        self.op = self.get_callable_op()
+        self.op_name = op.__name__
+        self.op = op
         self.results = {}
         self.devices = devices
         self.results = {}
         for device in self.devices:
             self.results[device] = {}

-    def get_callable_op(self):
-        name = self.spec.op
-        op: OpOverload = get_callable(name)
-        if self.out:
-            # Get the out variant op
-            op: OpOverload = to_variant(op, SchemaKind.out)
-        return op
-
     def report_device(self, device):
         print(f"Device: {device}\n")
         failures = []
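
Note: with this change the caller resolves the OpOverload itself (the runner no longer looks it up from the spec name), and an optional TensorConfig can describe where the generated input tensors live. A minimal sketch of the new call shape, assuming "add.Tensor" is present in SpecDictDB, that TensorConfig accepts a device argument, and that SpecRunner is importable as calibrator.runner.SpecRunner (none of which is shown in this diff):

from calibrator.runner import SpecRunner
from facto.inputgen.utils.config import TensorConfig
from facto.specdb.db import SpecDictDB
from facto.utils.ops import get_op_overload

# Resolve the operator up front; SpecRunner now receives it explicitly.
op = get_op_overload("add.Tensor")
spec = SpecDictDB["add.Tensor"]

# Assumption: TensorConfig(device=...) names the device the generated tensors
# start on; move_to_device treats that device (rather than a hard-coded "cpu")
# as the source when deciding whether a copy is needed.
config = TensorConfig(device="cuda")  # requires a CUDA-enabled PyTorch build

SpecRunner(
    op, spec, valid=True, out=False, devices=("cpu", "cuda"), config=config
).run()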
@@ -78,7 +73,7 @@ def report_inconsistencies(self):
             for device in self.devices[1:]:
                 res ^= self.results[device][meta_tuple]
             if not res:
-                inconsistencies.append(meta_tuple)
+                inconsistencies.add(meta_tuple)
         if len(inconsistencies) > 0:
             print("INCONSISTENCIES\n")
             for meta_tuple in inconsistencies:
@@ -98,26 +93,26 @@ def run(self):
     def move_to_device(
         self,
         device: str,
-        cpu_posargs: List[Any],
-        cpu_inkwargs: OrderedDict[str, Any],
-        cpu_outargs: OrderedDict[str, Any],
+        src_posargs: List[Any],
+        src_inkwargs: OrderedDict[str, Any],
+        src_outargs: OrderedDict[str, Any],
     ):
-        if device == "cpu":
-            return cpu_posargs, cpu_inkwargs, cpu_outargs
+        if device == ("cpu" if self.config is None else self.config.device):
+            return src_posargs, src_inkwargs, src_outargs
         posargs = []
         inkwargs = OrderedDict()
         outargs = OrderedDict()
-        for arg in cpu_posargs:
+        for arg in src_posargs:
             new = arg
             if isinstance(arg, torch.Tensor):
                 new = arg.to(device=device)
             posargs.append(new)
-        for k, v in cpu_inkwargs.items():
+        for k, v in src_inkwargs.items():
             new = v
             if isinstance(v, torch.Tensor):
                 new = v.to(device=device)
             inkwargs[k] = new
-        for k, v in cpu_outargs.items():
+        for k, v in src_outargs.items():
             new = v
             if isinstance(v, torch.Tensor):
                 new = v.to(device=device)
@@ -133,9 +128,30 @@ def run_meta_tuple(
         posargs, inkwargs, outargs = self.move_to_device(
             device, posargs, inkwargs, outargs
         )
-        success, res, posargs, inkwargs, outargs = self.run_values(
+        success, res, res_posargs, res_inkwargs, res_outargs = self.run_values(
             meta_tuple, posargs, inkwargs, outargs
         )
+        if (
+            self.valid
+            and success
+            and device != "cpu"
+            and isinstance(res, torch.Tensor)
+        ):
+            cpu_posargs, cpu_inkwargs, cpu_outargs = self.move_to_device(
+                "cpu", posargs, inkwargs, outargs
+            )
+            (
+                cpu_success,
+                cpu_res,
+                cpu_res_posargs,
+                cpu_res_inkwargs,
+                cpu_res_outargs,
+            ) = self.run_values(meta_tuple, cpu_posargs, cpu_inkwargs, cpu_outargs)
+            if cpu_success and cpu_res is not None:
+                if not torch.allclose(cpu_res, res.to("cpu")):
+                    logging.warning(
+                        f"NOT ALL CLOSE opname: {self.op_name}, meta_tuple: {smt(meta_tuple)}, device: {device}, {(cpu_res.to(torch.float) - res.to('cpu').to(torch.float)).abs().max()}"
+                    )
         mt = smt(meta_tuple)
         if mt in self.results[device]:
             logging.warning(f"Repeated meta_tuple {mt}")
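
The block added here re-runs the same meta tuple on CPU and warns when a non-CPU result drifts from the CPU reference. The comparison it performs, shown standalone as a sketch (check_against_cpu is a hypothetical helper, not part of this diff):

import logging

import torch


def check_against_cpu(
    res: torch.Tensor, cpu_res: torch.Tensor, op_name: str, device: str
) -> bool:
    # Mirrors the added check: agreement is judged with torch.allclose, and the
    # maximum absolute difference is logged when the results diverge.
    if torch.allclose(cpu_res, res.to("cpu")):
        return True
    max_abs_diff = (cpu_res.to(torch.float) - res.to("cpu").to(torch.float)).abs().max()
    logging.warning(f"NOT ALL CLOSE opname: {op_name}, device: {device}, {max_abs_diff}")
    return False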
@@ -183,7 +199,10 @@ def main():
         raise RuntimeError(f"Op {args.op} not found in SpecDB")

     spec = SpecDictDB[args.op]
-    SpecRunner(spec, valid=not args.invalid, out=args.out, devices=args.devices).run()
+    op = get_op_overload(args.op)
+    SpecRunner(
+        op, spec, valid=not args.invalid, out=args.out, devices=args.devices
+    ).run()


 if __name__ == "__main__":
