Skip to content

Commit 78fcdc6

Browse files
eigen-kfacebook-github-bot
authored andcommitted
Ensure the correct output data type for the full op.
Summary: Fixes tests Differential Revision: D80125213
1 parent 2e44b3c commit 78fcdc6

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2327,10 +2327,13 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
23272327
# Cast the const_arg to the dtype of the x_arg
23282328
full_arg = self.resolve_full_arg(x_arg, const_arg)
23292329

2330+
full_output_dtype = torch.int32 if isinstance(full_arg, int) else torch.float32
2331+
23302332
# Extract an argument to a separate full op.
23312333
with graph_module.graph.inserting_before(mul_node):
23322334
full_node = graph_module.graph.call_function(
2333-
torch.ops.aten.full.default, args=([1], full_arg)
2335+
torch.ops.aten.full.default, args=([1], full_arg),
2336+
kwargs={"dtype": full_output_dtype},
23342337
)
23352338
full_node.meta = mul_node.meta
23362339
full_node.meta["val"] = [1]

0 commit comments

Comments
 (0)