From ecdb9139ba26f2ab7b2a41a08050977f54bc2bef Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Wed, 7 Jan 2026 13:51:04 -0800 Subject: [PATCH] Fix visualization bug Summary: In the previous approach, FakeTensor shape propagation converted the constant to a fake tensor, which meant there was no data available for visualization. Now the value is encoded as a scalar argument in the graph, so the output metadata can be computed without needing the actual tensor data. If a downstream op needs the value, it can trace back through the graph and see that it came from full(0) - the value is preserved in the graph structure rather than in the tensor data. Reviewed By: skrtskrtfb Differential Revision: D90275311 --- backends/cadence/aot/replace_ops.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 72d7fc164c4..5cf36bdc605 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -1458,17 +1458,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # kernel_width for conv1d, and X = kernel_height * kernel_width for # conv2d. We extract X as the kernel_size for im2row. kernel_size = list(weight_shape[1:-1] if channel_last else weight_shape[2:]) - # If the transposed_convolution op was quantized, we need the input tensor's - # zero_point for im2row. Otherwise in_zero_point defaults to a zero - # tensor. For quantized ops, the input_zero_point is at args[8] as an int, - # and we need to convert it to a tensor since transposed_im2row expects - # a Tensor for in_zero_point (unlike im2row which expects an int). 
assert isinstance(in_tensor, torch.fx.Node) - in_zero_point = ( - torch.tensor(node.args[8], dtype=torch.int32) - if quantized_op - else torch.tensor(0, dtype=torch.int32) - ) # Cast to appropriate types stride = cast(Sequence[int], stride) @@ -1493,6 +1483,20 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: graph = node.graph + # If the transposed_convolution op was quantized, we need the input tensor's + # zero_point for im2row. Otherwise in_zero_point defaults to a zero tensor. + # We create the tensor as a graph node using aten.full to avoid + # DataDependentOutputException during FakeTensor shape propagation. + in_zero_point_val = node.args[8] if quantized_op else 0 + in_zero_point = graph.call_function( + exir_ops.edge.aten.full.default, + args=([1], in_zero_point_val), + kwargs={"dtype": torch.int32}, + ) + # Insert the node before the current node + node.prepend(in_zero_point) + in_zero_point.meta = node.meta + # Create a transposed_im2row node with the input. This will create a 2d # matrix of shape [out_height*out_weight, X*in_channels]. X is as # defined in the comment above.