Skip to content

Commit adecb0c

Browse files
mlazos authored and pytorchmergebot committed
[Cutlass-EVT] Fix buffer size issues (pytorch#161335)
Pull Request resolved: pytorch#161335 Approved by: https://github.com/henrylhtsang ghstack dependencies: pytorch#161398
1 parent d57c79e commit adecb0c

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

torch/_inductor/codegen/cuda/gemm_template.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,6 +1168,10 @@ def render( # type: ignore[override]
11681168
op = self.swap_XW(op)
11691169
should_swap_xw = True
11701170

1171+
name_to_buffer = {node.get_name(): node for node in self.input_nodes}
1172+
# handle the fake output buffer during lowering
1173+
name_to_buffer[Y.get_name()] = Y # type: ignore[assignment]
1174+
11711175
if epilogue_nodes or is_scaled_mm:
11721176
if epilogue_nodes:
11731177
(
@@ -1179,12 +1183,15 @@ def render( # type: ignore[override]
11791183
Y.get_name(), epilogue_nodes, V.kernel.removed_buffers
11801184
)
11811185

1186+
# TODO: mlazos remove this by returning buffer metadata from
1187+
# ir_to_evt_python code
1188+
for name, buf in (
1189+
V.graph.name_to_buffer | V.graph.graph_inputs
1190+
).items():
1191+
if name not in name_to_buffer:
1192+
name_to_buffer[name] = buf # type: ignore[assignment]
1193+
11821194
D_output_name = var_name_to_buffer_name["D"]
1183-
name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
1184-
for name in V.graph.constants.keys():
1185-
name_to_buffer[name] = V.graph.add_tensor_constant(
1186-
V.graph.constants[name], name
1187-
)
11881195
D_output_buffer = name_to_buffer[D_output_name]
11891196
Y = D_output_buffer # type: ignore[assignment]
11901197
# Interestingly, I don't think the rest of the layout matters here since we
@@ -1229,6 +1236,7 @@ def render( # type: ignore[override]
12291236
op,
12301237
evt_py_code,
12311238
var_name_to_buffer_name,
1239+
name_to_buffer,
12321240
Y.get_dtype(),
12331241
acc_dtype,
12341242
)
@@ -1327,6 +1335,7 @@ def _render_evt(
13271335
op: GemmOperation,
13281336
evt_py_code: str,
13291337
buffer_renames: dict[str, str],
1338+
name_to_buffer: dict[str, Buffer],
13301339
output_dtype: torch.dtype,
13311340
accumulator_dtype: torch.dtype,
13321341
) -> tuple[str, str, str, EVTArgRenames]: # type: ignore[name-defined] # noqa: F821
@@ -1488,23 +1497,15 @@ def _render_evt(
14881497
op: GemmOperation,
14891498
evt_py_code: str,
14901499
var_name_to_buffer_name: dict[str, str],
1500+
name_to_buffer: dict[str, Buffer],
14911501
output_dtype: torch.dtype,
14921502
accumulator_dtype: torch.dtype,
14931503
) -> tuple[str, str, str, EVTArgRenames]:
14941504
from .cutlass_lib_extensions.evt_extensions import create_example_tensors, trace
14951505

1496-
name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
1497-
1498-
for name in V.graph.constants.keys():
1499-
name_to_buffer[name] = V.graph.add_tensor_constant(
1500-
V.graph.constants[name], name
1501-
)
1502-
1503-
# handle the fake output buffer during lowering
1504-
name_to_buffer[self.output_node.get_name()] = self.output_node # type: ignore[assignment]
1505-
15061506
acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype)
15071507
output_dtype = torch_dtype_to_cutlass_type(output_dtype)
1508+
15081509
examples = create_example_tensors(
15091510
var_name_to_buffer_name,
15101511
name_to_buffer, # type: ignore[arg-type]

0 commit comments

Comments (0)