Skip to content

Commit a1076af

Browse files
authored
Fix visualization bug
Differential Revision: D90275311 Pull Request resolved: #16497
1 parent 41892be commit a1076af

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,17 +1458,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
14581458
# kernel_width for conv1d, and X = kernel_height * kernel_width for
14591459
# conv2d. We extract X as the kernel_size for im2row.
14601460
kernel_size = list(weight_shape[1:-1] if channel_last else weight_shape[2:])
1461-
# If the transposed_convolution op was quantized, we need the input tensor's
1462-
# zero_point for im2row. Otherwise in_zero_point defaults to a zero
1463-
# tensor. For quantized ops, the input_zero_point is at args[8] as an int,
1464-
# and we need to convert it to a tensor since transposed_im2row expects
1465-
# a Tensor for in_zero_point (unlike im2row which expects an int).
14661461
assert isinstance(in_tensor, torch.fx.Node)
1467-
in_zero_point = (
1468-
torch.tensor(node.args[8], dtype=torch.int32)
1469-
if quantized_op
1470-
else torch.tensor(0, dtype=torch.int32)
1471-
)
14721462

14731463
# Cast to appropriate types
14741464
stride = cast(Sequence[int], stride)
@@ -1493,6 +1483,20 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
14931483

14941484
graph = node.graph
14951485

1486+
# If the transposed_convolution op was quantized, we need the input tensor's
1487+
# zero_point for im2row. Otherwise in_zero_point defaults to a zero tensor.
1488+
# We create the tensor as a graph node using aten.full to avoid
1489+
# DataDependentOutputException during FakeTensor shape propagation.
1490+
in_zero_point_val = node.args[8] if quantized_op else 0
1491+
in_zero_point = graph.call_function(
1492+
exir_ops.edge.aten.full.default,
1493+
args=([1], in_zero_point_val),
1494+
kwargs={"dtype": torch.int32},
1495+
)
1496+
# Insert the node before the current node
1497+
node.prepend(in_zero_point)
1498+
in_zero_point.meta = node.meta
1499+
14961500
# Create a transposed_im2row node with the input. This will create a 2d
14971501
# matrix of shape [out_height*out_weight, X*in_channels]. X is as
14981502
# defined in the comment above.

0 commit comments

Comments
 (0)