
Commit 0786faa

[aoti-cuda] Directly pass user input placeholders to torch._inductor.aot_compile (#14707)
Summary: As titled, pass the user input placeholders directly to `torch._inductor.aot_compile` instead of materializing fake inputs with `torch.randn`; this avoids issues like SymInt handling.
1 parent 6a238e3 commit 0786faa

File tree

1 file changed (+1, −13 lines)

backends/cuda/cuda_backend.py

Lines changed: 1 addition & 13 deletions
```diff
@@ -117,18 +117,6 @@ def preprocess(
         if node.op == "placeholder" and node.name in user_input_names:
             user_input_placeholders.append(node.meta["val"])

-        # Create pseudo user inputs using torch.randn and metadata from input placeholders
-        faked_user_inputs = []
-        for placeholder in user_input_placeholders:
-            if isinstance(placeholder, torch.Tensor):
-                # Generate fake input with same shape and dtype, on CUDA
-                fake_input = torch.randn(
-                    placeholder.shape, dtype=placeholder.dtype, device="cuda"
-                )
-                faked_user_inputs.append(fake_input)
-
-        faked_user_inputs = tuple(faked_user_inputs)
-
         options: dict[str, typing.Any] = {
             # Embed CUDA kernel binaries directly into the compiled shared object
             "aot_inductor.embed_kernel_binary": True,
@@ -145,7 +133,7 @@ def preprocess(
         }

         with collect_unsupported_fallback_kernels():
-            so_path = torch._inductor.aot_compile(edge_program_module, faked_user_inputs, options=options)  # type: ignore[arg-type]
+            so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
         if len(missing_fallback_kernels) > 0:
             formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
             raise RuntimeError(
```
