Commit ecdb913

Andrew Grebenisan authored and facebook-github-bot committed

Fix visualization bug

Summary: In the previous approach, FakeTensor shape propagation converted the constant to a fake tensor, which left no data for visualization. Now the value is encoded as a scalar argument of a node in the graph, so output metadata can be computed without needing actual tensor data. If a downstream op needs the value, it can trace back through the graph and see that it came from full(0): the constant is preserved in the graph structure rather than in tensor data.

Reviewed By: skrtskrtfb

Differential Revision: D90275311

1 parent 3090486 commit ecdb913
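The failure mode the summary describes can be sketched with a minimal, hypothetical reproduction (not the original pass code): under FakeTensor propagation, tensor factory calls produce FakeTensors that carry only metadata such as shape and dtype, so a constant baked into a real tensor at trace time loses its value for any pass that later wants to inspect the data.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Inside a FakeTensorMode, factory functions return FakeTensors:
# they have shape/dtype metadata but no backing data, so the concrete
# zero-point value would be unavailable to a visualization pass.
with FakeTensorMode():
    zp = torch.zeros(1, dtype=torch.int32)
    print(type(zp).__name__, tuple(zp.shape), zp.dtype)
```

Encoding the value as a scalar arg of a graph node sidesteps this, because the metadata of a `full` op is computable from its args alone.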

File tree

1 file changed (+14, -10)


backends/cadence/aot/replace_ops.py  (14 additions, 10 deletions)

@@ -1458,17 +1458,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
         # kernel_width for conv1d, and X = kernel_height * kernel_width for
         # conv2d. We extract X as the kernel_size for im2row.
         kernel_size = list(weight_shape[1:-1] if channel_last else weight_shape[2:])
-        # If the transposed_convolution op was quantized, we need the input tensor's
-        # zero_point for im2row. Otherwise in_zero_point defaults to a zero
-        # tensor. For quantized ops, the input_zero_point is at args[8] as an int,
-        # and we need to convert it to a tensor since transposed_im2row expects
-        # a Tensor for in_zero_point (unlike im2row which expects an int).
         assert isinstance(in_tensor, torch.fx.Node)
-        in_zero_point = (
-            torch.tensor(node.args[8], dtype=torch.int32)
-            if quantized_op
-            else torch.tensor(0, dtype=torch.int32)
-        )

         # Cast to appropriate types
         stride = cast(Sequence[int], stride)
@@ -1493,6 +1483,20 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:

         graph = node.graph

+        # If the transposed_convolution op was quantized, we need the input tensor's
+        # zero_point for im2row. Otherwise in_zero_point defaults to a zero tensor.
+        # We create the tensor as a graph node using aten.full to avoid
+        # DataDependentOutputException during FakeTensor shape propagation.
+        in_zero_point_val = node.args[8] if quantized_op else 0
+        in_zero_point = graph.call_function(
+            exir_ops.edge.aten.full.default,
+            args=([1], in_zero_point_val),
+            kwargs={"dtype": torch.int32},
+        )
+        # Insert the node before the current node
+        node.prepend(in_zero_point)
+        in_zero_point.meta = node.meta
+
         # Create a transposed_im2row node with the input. This will create a 2d
         # matrix of shape [out_height*out_weight, X*in_channels]. X is as
         # defined in the comment above.
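The pattern in the added lines can be illustrated with vanilla torch.fx. This is a minimal sketch, not the original pass: it substitutes `torch.full` for `exir_ops.edge.aten.full.default` (which requires ExecuTorch) and a toy traced function for the real graph, but uses the same `graph.call_function` + `Node.prepend` mechanics.

```python
import torch
import torch.fx


def f(x):
    return x + x


gm = torch.fx.symbolic_trace(f)
graph = gm.graph

# The consumer node before which the constant should be materialized.
add_node = next(n for n in graph.nodes if n.op == "call_function")

# Encode the constant as a full([1], value) graph node: the value lives
# in the node's args, so fake-tensor propagation can compute output
# metadata without real data, and the value stays recoverable by
# walking the graph back to this node.
in_zero_point = graph.call_function(
    torch.full, args=([1], 0), kwargs={"dtype": torch.int32}
)
add_node.prepend(in_zero_point)  # move it before the consumer

graph.lint()
gm.recompile()
print(in_zero_point.target, in_zero_point.args)
```

Because the constant is an argument of the `full` node rather than data inside a tensor object, a downstream pass can recover it structurally, which is exactly what the commit summary relies on for visualization.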
