Skip to content

Commit 46dd51a

Browse files
authored
Ensure the correct output data type for the full op.
Differential Revision: D80125213 Pull Request resolved: #13349
1 parent 7f8f9d7 commit 46dd51a

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2327,10 +2327,16 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
23272327
# Cast the const_arg to the dtype of the x_arg
23282328
full_arg = self.resolve_full_arg(x_arg, const_arg)
23292329

2330+
full_output_dtype = (
2331+
torch.int32 if isinstance(full_arg, int) else torch.float32
2332+
)
2333+
23302334
# Extract an argument to a separate full op.
23312335
with graph_module.graph.inserting_before(mul_node):
23322336
full_node = graph_module.graph.call_function(
2333-
torch.ops.aten.full.default, args=([1], full_arg)
2337+
torch.ops.aten.full.default,
2338+
args=([1], full_arg),
2339+
kwargs={"dtype": full_output_dtype},
23342340
)
23352341
full_node.meta = mul_node.meta
23362342
full_node.meta["val"] = [1]

0 commit comments

Comments
 (0)