Skip to content

Commit 140987e

Browse files
cptspacemanspiffswolchok
authored andcommitted
Added better error messages for type validators (stack trace)
1 parent 9bd18f6 commit 140987e

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

exir/verification/arg_validator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class EdgeOpArgValidator(torch.fx.Interpreter):
3737

3838
def __init__(self, graph_module: torch.fx.GraphModule) -> None:
3939
super().__init__(graph_module)
40-
self.violating_ops: Dict[EdgeOpOverload, Dict[str, Optional[torch.dtype]]] = (
40+
self.violating_ops: Dict[EdgeOpOverload, Tuple[Dict[str, Optional[torch.dtype]], torch.fx.Node]] = (
4141
defaultdict(dict)
4242
)
4343

@@ -125,5 +125,5 @@ def call_function( # noqa: C901 # pyre-fixme[14]
125125

126126
valid = target._schema.dtype_constraint.validate(tensor_arg_types)
127127
if not valid:
128-
self.violating_ops[target] = tensor_arg_types
128+
self.violating_ops[target] = (tensor_arg_types, self.node)
129129
return super().call_function(target, args, kwargs) # pyre-fixme[6]

exir/verification/verifier.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,15 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
189189
return
190190

191191
if validator.violating_ops:
192+
error_msg = ""
193+
for op, node in validator.violating_ops.items():
194+
# error_msg += f"#####################################################\n"
195+
error_msg += f"Operator: {op} with args: {node[0]}\n"
196+
error_msg += f"stack trace: {node[1].stack_trace}\n\n"
197+
# error_msg += f"#####################################################\n"
192198
raise SpecViolationError(
193-
f"These operators are taking Tensor inputs with mismatched dtypes: {validator.violating_ops}"
194-
"Please make sure the dtypes of the Tensor inputs are the same as the dtypes of the corresponding "
199+
f"These operators are taking Tensor inputs with mismatched dtypes:\n{error_msg}\n"
200+
"Please make sure the dtypes of the Tensor inputs are the same as the dtypes of the corresponding outputs."
195201
)
196202

197203

0 commit comments

Comments
 (0)