Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 44 additions & 25 deletions calibrator/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@
import argparse
import logging
import sys
from typing import Any, List, OrderedDict, Tuple
from typing import Any, List, Optional, OrderedDict, Tuple

import torch
from executorch.exir.dialects.edge.op.api import get_callable, to_variant
from facto.inputgen.argtuple.engine import MetaArgTupleEngine
from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
from facto.inputgen.argument.engine import MetaArg
from facto.inputgen.specs.model import Spec
from facto.inputgen.utils.config import TensorConfig
from facto.specdb.db import SpecDictDB
from facto.utils.ops import get_op_overload
from torch._ops import OpOverload
from torchgen.model import SchemaKind

logging.basicConfig(stream=sys.stderr, level=logging.WARNING)

Expand All @@ -29,32 +29,27 @@ def smt(meta_tuple):
class SpecRunner:
def __init__(
self,
op: OpOverload,
spec: Spec,
*,
valid: bool = True,
out: bool = False,
devices: Tuple[str] = ("cpu",),
config: Optional[TensorConfig] = None,
):
self.spec = spec
self.generator = ArgumentTupleGenerator(self.spec)
self.config = config
self.generator = ArgumentTupleGenerator(self.spec, config=config)
self.valid = valid
self.out = out
self.op_name = spec.op
self.op = self.get_callable_op()
self.op_name = op.__name__
self.op = op
self.results = {}
self.devices = devices
self.results = {}
for device in self.devices:
self.results[device] = {}

def get_callable_op(self):
name = self.spec.op
op: OpOverload = get_callable(name)
if self.out:
# Get the out variant op
op: OpOverload = to_variant(op, SchemaKind.out)
return op

def report_device(self, device):
print(f"Device: {device}\n")
failures = []
Expand All @@ -78,7 +73,7 @@ def report_inconsistencies(self):
for device in self.devices[1:]:
res ^= self.results[device][meta_tuple]
if not res:
inconsistencies.append(meta_tuple)
inconsistencies.add(meta_tuple)
if len(inconsistencies) > 0:
print("INCONSISTENCIES\n")
for meta_tuple in inconsistencies:
Expand All @@ -98,26 +93,26 @@ def run(self):
def move_to_device(
self,
device: str,
cpu_posargs: List[Any],
cpu_inkwargs: OrderedDict[str, Any],
cpu_outargs: OrderedDict[str, Any],
src_posargs: List[Any],
src_inkwargs: OrderedDict[str, Any],
src_outargs: OrderedDict[str, Any],
):
if device == "cpu":
return cpu_posargs, cpu_inkwargs, cpu_outargs
if device == ("cpu" if self.config is None else self.config.device):
return src_posargs, src_inkwargs, src_outargs
posargs = []
inkwargs = OrderedDict()
outargs = OrderedDict()
for arg in cpu_posargs:
for arg in src_posargs:
new = arg
if isinstance(arg, torch.Tensor):
new = arg.to(device=device)
posargs.append(new)
for k, v in cpu_inkwargs.items():
for k, v in src_inkwargs.items():
new = v
if isinstance(v, torch.Tensor):
new = v.to(device=device)
inkwargs[k] = new
for k, v in cpu_outargs.items():
for k, v in src_outargs.items():
new = v
if isinstance(v, torch.Tensor):
new = v.to(device=device)
Expand All @@ -133,9 +128,30 @@ def run_meta_tuple(
posargs, inkwargs, outargs = self.move_to_device(
device, posargs, inkwargs, outargs
)
success, res, posargs, inkwargs, outargs = self.run_values(
success, res, res_posargs, res_inkwargs, res_outargs = self.run_values(
meta_tuple, posargs, inkwargs, outargs
)
if (
self.valid
and success
and device != "cpu"
and isinstance(res, torch.Tensor)
):
cpu_posargs, cpu_inkwargs, cpu_outargs = self.move_to_device(
"cpu", posargs, inkwargs, outargs
)
(
cpu_success,
cpu_res,
cpu_res_posargs,
cpu_res_inkwargs,
cpu_res_outargs,
) = self.run_values(meta_tuple, cpu_posargs, cpu_inkwargs, cpu_outargs)
if cpu_success and cpu_res is not None:
if not torch.allclose(cpu_res, res.to("cpu")):
logging.warning(
f"NOT ALL CLOSE opname: {self.op_name}, meta_tuple: {smt(meta_tuple)}, device: {device}, {(cpu_res.to(torch.float) - res.to('cpu').to(torch.float)).abs().max()}"
)
mt = smt(meta_tuple)
if mt in self.results[device]:
logging.warning(f"Repeated meta_tuple {mt}")
Expand Down Expand Up @@ -183,7 +199,10 @@ def main():
raise RuntimeError(f"Op {args.op} not found in SpecDB")

spec = SpecDictDB[args.op]
SpecRunner(spec, valid=not args.invalid, out=args.out, devices=args.devices).run()
op = get_op_overload(args.op)
SpecRunner(
op, spec, valid=not args.invalid, out=args.out, devices=args.devices
).run()


if __name__ == "__main__":
Expand Down