Skip to content

Commit b352408

Browse files
muchulee8pytorchmergebot
authored andcommitted
[AOTInductor] Generate kernels separately for const graph and main graph (pytorch#153040)
Summary: We should generate the kernel for const graph and main graph separately. The reason is that when we run autotuning, we would create separate kernel calls and we should make sure that main graph also contains the runner. Test Plan: python test/inductor/test_aot_inductor.py -k test_autotune_with_constant_folding Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D74347765](https://our.internmc.facebook.com/intern/diff/D74347765) Pull Request resolved: pytorch#153040 Approved by: https://github.com/angelayi
1 parent e5f8699 commit b352408

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

test/inductor/test_aot_inductor.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,31 @@ def forward(self, x):
320320
with config.patch({"aot_inductor.use_runtime_constant_folding": True}):
321321
self.check_model(Model(self.device), example_inputs)
322322

323+
def test_autotune_with_constant_folding(self):
324+
class Model(torch.nn.Module):
325+
def __init__(self, device) -> None:
326+
super().__init__()
327+
self.x = torch.randn(2048, 2048, dtype=torch.float16, device=device)
328+
329+
def _quantize(self, input):
330+
return torch.abs(input)
331+
332+
def forward(self, y):
333+
abs_weight = self._quantize(self.x)
334+
abs_y = self._quantize(y)
335+
336+
return abs_weight, abs_y
337+
338+
input1 = (torch.rand(2048, 2048, dtype=torch.float16, device=self.device),)
339+
model = Model(self.device).to(self.device)
340+
341+
_ = model(*input1)
342+
343+
ep = torch.export.export(model, input1, dynamic_shapes=None, strict=False)
344+
torch._inductor.aoti_compile_and_package(
345+
ep, inductor_configs={"aot_inductor.use_runtime_constant_folding": True}
346+
)
347+
323348
@requires_gpu
324349
def test_multi_device(self):
325350
if self.device == "cpu" and GPU_TYPE == "xpu":

torch/_inductor/graph.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1947,12 +1947,7 @@ def init_wrapper_code(
19471947
)
19481948

19491949
if self.const_module:
1950-
# If we have const module, we could reuse the kernels
1951-
# This could avoid duplication and save time on doing recompilation (if Triton.)
19521950
self.wrapper_code._names_iter = self.const_module.wrapper_code._names_iter
1953-
self.wrapper_code.src_to_kernel = (
1954-
self.const_module.wrapper_code.src_to_kernel
1955-
)
19561951

19571952
def extract_autotune_inputs(
19581953
self, example_inputs: list[Union[int, float, torch.Tensor]]

0 commit comments

Comments
 (0)