From ae8d647f50daad470eafed0739795f441d52c9af Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 18 Aug 2025 13:27:07 -0400 Subject: [PATCH] Update [ghstack-poisoned] --- calibrator/runner.py | 69 ++++++++++++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 25 deletions(-) diff --git a/calibrator/runner.py b/calibrator/runner.py index cd607ab..9edf9f1 100644 --- a/calibrator/runner.py +++ b/calibrator/runner.py @@ -7,17 +7,17 @@ import argparse import logging import sys -from typing import Any, List, OrderedDict, Tuple +from typing import Any, List, Optional, OrderedDict, Tuple import torch -from executorch.exir.dialects.edge.op.api import get_callable, to_variant from facto.inputgen.argtuple.engine import MetaArgTupleEngine from facto.inputgen.argtuple.gen import ArgumentTupleGenerator from facto.inputgen.argument.engine import MetaArg from facto.inputgen.specs.model import Spec +from facto.inputgen.utils.config import TensorConfig from facto.specdb.db import SpecDictDB +from facto.utils.ops import get_op_overload from torch._ops import OpOverload -from torchgen.model import SchemaKind logging.basicConfig(stream=sys.stderr, level=logging.WARNING) @@ -29,32 +29,27 @@ def smt(meta_tuple): class SpecRunner: def __init__( self, + op: OpOverload, spec: Spec, *, valid: bool = True, out: bool = False, devices: Tuple[str] = ("cpu",), + config: Optional[TensorConfig] = None, ): self.spec = spec - self.generator = ArgumentTupleGenerator(self.spec) + self.config = config + self.generator = ArgumentTupleGenerator(self.spec, config=config) self.valid = valid self.out = out - self.op_name = spec.op - self.op = self.get_callable_op() + self.op_name = op.__name__ + self.op = op self.results = {} self.devices = devices self.results = {} for device in self.devices: self.results[device] = {} - def get_callable_op(self): - name = self.spec.op - op: OpOverload = get_callable(name) - if self.out: - # Get the out variant op - op: OpOverload = to_variant(op, SchemaKind.out) - return op - def report_device(self, device): print(f"Device: {device}\n") failures = [] @@ -78,7 +73,7 @@ def report_inconsistencies(self): for device in self.devices[1:]: res ^= self.results[device][meta_tuple] if not res: - inconsistencies.append(meta_tuple) + inconsistencies.add(meta_tuple) if len(inconsistencies) > 0: print("INCONSISTENCIES\n") for meta_tuple in inconsistencies: @@ -98,26 +93,26 @@ def run(self): def move_to_device( self, device: str, - cpu_posargs: List[Any], - cpu_inkwargs: OrderedDict[str, Any], - cpu_outargs: OrderedDict[str, Any], + src_posargs: List[Any], + src_inkwargs: OrderedDict[str, Any], + src_outargs: OrderedDict[str, Any], ): - if device == "cpu": - return cpu_posargs, cpu_inkwargs, cpu_outargs + if device == ("cpu" if self.config is None else self.config.device): + return src_posargs, src_inkwargs, src_outargs posargs = [] inkwargs = OrderedDict() outargs = OrderedDict() - for arg in cpu_posargs: + for arg in src_posargs: new = arg if isinstance(arg, torch.Tensor): new = arg.to(device=device) posargs.append(new) - for k, v in cpu_inkwargs.items(): + for k, v in src_inkwargs.items(): new = v if isinstance(v, torch.Tensor): new = v.to(device=device) inkwargs[k] = new - for k, v in cpu_outargs.items(): + for k, v in src_outargs.items(): new = v if isinstance(v, torch.Tensor): new = v.to(device=device) @@ -133,9 +128,30 @@ def run_meta_tuple( posargs, inkwargs, outargs = self.move_to_device( device, posargs, inkwargs, outargs ) - success, res, posargs, inkwargs, outargs = self.run_values( + success, res, res_posargs, res_inkwargs, res_outargs = self.run_values( meta_tuple, posargs, inkwargs, outargs ) + if ( + self.valid + and success + and device != "cpu" + and isinstance(res, torch.Tensor) + ): + cpu_posargs, cpu_inkwargs, cpu_outargs = self.move_to_device( + "cpu", posargs, inkwargs, outargs + ) + ( + cpu_success, + cpu_res, + cpu_res_posargs, + cpu_res_inkwargs, + cpu_res_outargs, + ) = self.run_values(meta_tuple, cpu_posargs, cpu_inkwargs, cpu_outargs) + if cpu_success and cpu_res is not None: + if not torch.allclose(cpu_res, res.to("cpu")): + logging.warning( + f"NOT ALL CLOSE opname: {self.op_name}, meta_tuple: {smt(meta_tuple)}, device: {device}, {(cpu_res.to(torch.float) - res.to('cpu').to(torch.float)).abs().max()}" + ) mt = smt(meta_tuple) if mt in self.results[device]: logging.warning(f"Repeated meta_tuple {mt}") @@ -183,7 +199,10 @@ def main(): raise RuntimeError(f"Op {args.op} not found in SpecDB") spec = SpecDictDB[args.op] - SpecRunner(spec, valid=not args.invalid, out=args.out, devices=args.devices).run() + op = get_op_overload(args.op) + SpecRunner( + op, spec, valid=not args.invalid, out=args.out, devices=args.devices + ).run() if __name__ == "__main__":