@@ -1104,7 +1104,7 @@ def execute_cudnn_gemm_fp4_graph(
         UIDs.O_UID.value: c_final,
     }

-    if graph.get_workspace_size() > DEFAULT_WORKSPACE_SIZE:
+    if workspace_buffer.numel() < graph.get_workspace_size():
         workspace_buffer = torch.empty(
             graph.get_workspace_size(), device=a.device, dtype=torch.uint8
         )
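This first hunk fixes a latent under-allocation: the old guard compared the graph's workspace requirement against the fixed `DEFAULT_WORKSPACE_SIZE` constant rather than against the buffer actually passed in, so a caller-supplied buffer smaller than the default was never grown. The new guard checks the real buffer. A minimal sketch of this grow-on-demand pattern, independent of cuDNN (helper name hypothetical):

```python
import torch

def ensure_workspace(workspace: torch.Tensor, required_bytes: int) -> torch.Tensor:
    """Return a uint8 workspace of at least `required_bytes`, reusing the
    caller's buffer whenever it is already large enough."""
    if workspace.numel() < required_bytes:
        # Reallocate only on the slow path; callers that retain the
        # returned tensor pay this cost at most once per size increase.
        workspace = torch.empty(
            required_bytes, device=workspace.device, dtype=torch.uint8
        )
    return workspace
```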
@@ -1179,7 +1179,7 @@ def build_cudnn_gemm_with_per_tensor_q_graph(
     return graph


-def execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, alpha, c_final):
+def execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, alpha, c_final, workspace):
     variant_pack = {
         UIDs.A_UID.value: a,
         UIDs.B_UID.value: b,
@@ -1190,9 +1190,10 @@ def execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, alpha, c_final):
     stream = torch.cuda.current_stream(a.device)
     cudnn_handle = _get_cudnn_handle(stream)

-    workspace = torch.empty(
-        graph.get_workspace_size(), device=a.device, dtype=torch.uint8
-    )
+    if workspace.numel() < graph.get_workspace_size():
+        workspace = torch.empty(
+            graph.get_workspace_size(), device=a.device, dtype=torch.uint8
+        )

     graph.execute(variant_pack, workspace, handle=cudnn_handle)

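These two hunks thread a caller-owned workspace through the per-tensor-quantization executor instead of allocating a fresh buffer on every call, avoiding a `torch.empty` per launch. One caveat worth noting: rebinding `workspace` inside the function is invisible to the caller, so an undersized cached buffer would be reallocated on every call; returning the buffer would let the caller keep the grown one. A sketch of that variant (names hypothetical, launch elided):

```python
import torch

def execute_with_workspace(required_bytes: int, workspace: torch.Tensor) -> torch.Tensor:
    """Grow the workspace if needed, launch, and hand the buffer back so
    the caller can cache the larger allocation (hypothetical sketch)."""
    if workspace.numel() < required_bytes:
        workspace = torch.empty(
            required_bytes, device=workspace.device, dtype=torch.uint8
        )
    # ... graph.execute(variant_pack, workspace, handle=...) would go here ...
    return workspace

ws = torch.empty(1024, dtype=torch.uint8)
ws = execute_with_workspace(4096, ws)  # grows once, to 4096 bytes
ws = execute_with_workspace(2048, ws)  # fast path: reuses the 4096-byte buffer
```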
@@ -1216,14 +1217,10 @@ def _cudnn_gemm_fp8(
     dq_scale: torch.Tensor,
     out: Optional[torch.Tensor],
     torch_out_dtype: torch.dtype,
+    workspace: torch.Tensor,
 ):
     _check_cudnn_availability()

-    if out is None:
-        out = torch.empty(
-            a.shape[0], a.shape[1], b.shape[2], dtype=torch_out_dtype, device=a.device
-        )
-
     graph = build_cudnn_gemm_with_per_tensor_q_graph(
         a.shape,
         a.stride(),
@@ -1235,7 +1232,7 @@ def _cudnn_gemm_fp8(
         a.device,
     )

-    execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, dq_scale, out)
+    execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, dq_scale, out, workspace)

     return out

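With the `if out is None` branch removed, `_cudnn_gemm_fp8` now assumes its caller always passes a concrete output tensor; the `dtype=dtype` context at the top of the last hunk is presumably the tail of the allocation `bmm_fp8` performs before dispatching. The deleted logic amounted to the following helper, reproduced here only to document the (batch, m, n) shape convention (helper name hypothetical):

```python
import torch
from typing import Optional

def _alloc_bmm_out(
    a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor], dtype: torch.dtype
) -> torch.Tensor:
    """Allocate the (batch, m, n) output of a batched GEMM when the caller
    did not provide one; mirrors the logic deleted from _cudnn_gemm_fp8."""
    if out is None:
        out = torch.empty(
            a.shape[0], a.shape[1], b.shape[2], dtype=dtype, device=a.device
        )
    return out
```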
@@ -1550,12 +1547,12 @@ def bmm_fp8(
         dtype=dtype,
     )

+    workspace_buffer = _get_cache_buf(
+        "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device
+    )
     if backend == "cudnn":
-        return _cudnn_gemm_fp8(A, B, A_scale * B_scale, out, dtype)
+        return _cudnn_gemm_fp8(A, B, A_scale * B_scale, out, dtype, workspace_buffer)
     elif backend == "cublas":
-        workspace_buffer = _get_cache_buf(
-            "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device
-        )
         get_gemm_module().bmm_fp8(workspace_buffer, A, B, out, A_scale, B_scale)
     return out

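Finally, `bmm_fp8` hoists the `_get_cache_buf` call out of the cuBLAS branch so both backends share one cached workspace per device. Assuming `_get_cache_buf` behaves like a keyed, per-device buffer cache (its exact semantics are not shown in this diff), a self-contained stand-in could look like:

```python
import torch
from typing import Dict, Tuple

_buf_cache: Dict[Tuple[str, torch.device], torch.Tensor] = {}

def get_cache_buf(name: str, nbytes: int, device) -> torch.Tensor:
    """Hypothetical stand-in for _get_cache_buf: keeps one uint8 buffer per
    (name, device) key and grows it when a larger size is requested."""
    key = (name, torch.device(device))
    buf = _buf_cache.get(key)
    if buf is None or buf.numel() < nbytes:
        buf = torch.empty(nbytes, device=device, dtype=torch.uint8)
        _buf_cache[key] = buf
    return buf
```

Combined with the `numel()` checks in the execute helpers, the workspace is then allocated once per device and regrown only when a graph demands more than `DEFAULT_WORKSPACE_SIZE`.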