diff --git a/exir/tests/test_arg_validator.py b/exir/tests/test_arg_validator.py index d85ef81b90d..ede8b224329 100644 --- a/exir/tests/test_arg_validator.py +++ b/exir/tests/test_arg_validator.py @@ -64,7 +64,7 @@ def forward(self, x): ops.edge.aten._log_softmax.default.name(), ) self.assertDictEqual( - validator.violating_ops[key], + validator.violating_ops[key][0], { "self": torch.bfloat16, "__ret_0": torch.bfloat16, diff --git a/exir/verification/arg_validator.py b/exir/verification/arg_validator.py index c087944b12d..53e2d36d390 100644 --- a/exir/verification/arg_validator.py +++ b/exir/verification/arg_validator.py @@ -37,9 +37,9 @@ class EdgeOpArgValidator(torch.fx.Interpreter): def __init__(self, graph_module: torch.fx.GraphModule) -> None: super().__init__(graph_module) - self.violating_ops: Dict[EdgeOpOverload, Dict[str, Optional[torch.dtype]]] = ( - defaultdict(dict) - ) + self.violating_ops: Dict[ + EdgeOpOverload, Tuple[Dict[str, Optional[torch.dtype]], torch.fx.Node] + ] = defaultdict(dict) def run_node(self, n: torch.fx.Node) -> None: self.node = n @@ -125,5 +125,5 @@ def call_function( # noqa: C901 # pyre-fixme[14] valid = target._schema.dtype_constraint.validate(tensor_arg_types) if not valid: - self.violating_ops[target] = tensor_arg_types + self.violating_ops[target] = (tensor_arg_types, self.node) return super().call_function(target, args, kwargs) # pyre-fixme[6] diff --git a/exir/verification/verifier.py b/exir/verification/verifier.py index 2ad453ffede..166e6a758a5 100644 --- a/exir/verification/verifier.py +++ b/exir/verification/verifier.py @@ -189,9 +189,15 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None: return if validator.violating_ops: + error_msg = "" + for op, node in validator.violating_ops.items(): + # error_msg += f"#####################################################\n" + error_msg += f"\nOperator: {op} with args: {node[0]}\n" + error_msg += f"stack trace: {node[1].stack_trace}\n" + # error_msg += f"#####################################################\n" raise SpecViolationError( - f"These operators are taking Tensor inputs with mismatched dtypes: {validator.violating_ops}" - "Please make sure the dtypes of the Tensor inputs are the same as the dtypes of the corresponding " + f"These operators are taking Tensor inputs with mismatched dtypes:\n{error_msg}" + "Please make sure the dtypes of the Tensor inputs are the same as the dtypes of the corresponding outputs." )