@@ -1168,6 +1168,10 @@ def render(  # type: ignore[override]
             op = self.swap_XW(op)
             should_swap_xw = True
 
+        name_to_buffer = {node.get_name(): node for node in self.input_nodes}
+        # handle the fake output buffer during lowering
+        name_to_buffer[Y.get_name()] = Y  # type: ignore[assignment]
+
         if epilogue_nodes or is_scaled_mm:
             if epilogue_nodes:
                 (
@@ -1179,12 +1183,15 @@ def render(  # type: ignore[override]
                     Y.get_name(), epilogue_nodes, V.kernel.removed_buffers
                 )
 
+                # TODO: mlazos remove this by returning buffer metadata from
+                # ir_to_evt_python code
+                for name, buf in (
+                    V.graph.name_to_buffer | V.graph.graph_inputs
+                ).items():
+                    if name not in name_to_buffer:
+                        name_to_buffer[name] = buf  # type: ignore[assignment]
+
                 D_output_name = var_name_to_buffer_name["D"]
-                name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
-                for name in V.graph.constants.keys():
-                    name_to_buffer[name] = V.graph.add_tensor_constant(
-                        V.graph.constants[name], name
-                    )
                 D_output_buffer = name_to_buffer[D_output_name]
                 Y = D_output_buffer  # type: ignore[assignment]
                 # Interestingly, I don't think the rest of the layout matters here since we
@@ -1229,6 +1236,7 @@ def render(  # type: ignore[override]
                 op,
                 evt_py_code,
                 var_name_to_buffer_name,
+                name_to_buffer,
                 Y.get_dtype(),
                 acc_dtype,
             )
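The render() hunks above replace a lazily built, graph-derived mapping with one seeded eagerly from the template's own input nodes (plus the fake output buffer used during lowering), backfilled from the graph-level maps only for names not already present. Below is a minimal, runnable sketch of that dict-building pattern; FakeNode and graph_buffers are hypothetical stand-ins for inductor's IR nodes and the V.graph.name_to_buffer | V.graph.graph_inputs union, not the real API.

class FakeNode:
    """Hypothetical stand-in for an inductor IR node / Buffer."""

    def __init__(self, name: str) -> None:
        self._name = name

    def get_name(self) -> str:
        return self._name


input_nodes = [FakeNode("X"), FakeNode("W")]
Y = FakeNode("buf_out")  # stands in for the fake output buffer during lowering

# Seed the mapping from the template's own inputs, then register the output.
name_to_buffer = {node.get_name(): node for node in input_nodes}
name_to_buffer[Y.get_name()] = Y

# Backfill from the graph-level maps without clobbering seeded entries,
# mirroring the `if name not in name_to_buffer` guard in the diff.
graph_buffers = {"buf0": FakeNode("buf0"), "X": FakeNode("X")}
for name, buf in graph_buffers.items():
    if name not in name_to_buffer:
        name_to_buffer[name] = buf

assert name_to_buffer["X"] is input_nodes[0]  # the seeded entry wins
assert "buf0" in name_to_buffer  # graph-only names are added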
@@ -1327,6 +1335,7 @@ def _render_evt(
         op: GemmOperation,
         evt_py_code: str,
         buffer_renames: dict[str, str],
+        name_to_buffer: dict[str, Buffer],
         output_dtype: torch.dtype,
         accumulator_dtype: torch.dtype,
     ) -> tuple[str, str, str, EVTArgRenames]:  # type: ignore[name-defined]  # noqa: F821
@@ -1488,23 +1497,15 @@ def _render_evt(
         op: GemmOperation,
         evt_py_code: str,
         var_name_to_buffer_name: dict[str, str],
+        name_to_buffer: dict[str, Buffer],
         output_dtype: torch.dtype,
         accumulator_dtype: torch.dtype,
     ) -> tuple[str, str, str, EVTArgRenames]:
         from .cutlass_lib_extensions.evt_extensions import create_example_tensors, trace
 
-        name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs
-
-        for name in V.graph.constants.keys():
-            name_to_buffer[name] = V.graph.add_tensor_constant(
-                V.graph.constants[name], name
-            )
-
-        # handle the fake output buffer during lowering
-        name_to_buffer[self.output_node.get_name()] = self.output_node  # type: ignore[assignment]
-
         acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype)
         output_dtype = torch_dtype_to_cutlass_type(output_dtype)
+
         examples = create_example_tensors(
             var_name_to_buffer_name,
             name_to_buffer,  # type: ignore[arg-type]
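The _render_evt hunks turn the mapping into an explicit parameter: each overload previously rebuilt name_to_buffer from V.graph.name_to_buffer, graph_inputs, and constants, and now simply receives the one the caller constructed. A hedged sketch of that dependency-injection shape follows, assuming simplified signatures; render_evt and make_examples are illustrative stand-ins, not the actual inductor functions.

from typing import Any


def make_examples(
    var_name_to_buffer_name: dict[str, str],
    name_to_buffer: dict[str, Any],
) -> dict[str, Any]:
    # Stand-in for create_example_tensors: resolve EVT variable names
    # ("D", "aux", ...) to the buffers they name.
    return {
        var: name_to_buffer[buf_name]
        for var, buf_name in var_name_to_buffer_name.items()
    }


def render_evt(
    var_name_to_buffer_name: dict[str, str],
    name_to_buffer: dict[str, Any],
) -> dict[str, Any]:
    # After the diff, the mapping arrives as an argument; the overloads no
    # longer recompute it from graph state.
    return make_examples(var_name_to_buffer_name, name_to_buffer)


buffers = {"buf0": object(), "buf1": object()}
examples = render_evt({"D": "buf0", "aux": "buf1"}, buffers)
assert examples["D"] is buffers["buf0"]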