@@ -1179,7 +1179,7 @@ def build_cudnn_gemm_with_per_tensor_q_graph(
1179
1179
return graph
1180
1180
1181
1181
1182
- def execute_cudnn_gemm_with_per_tensor_q_graph (graph , a , b , alpha , c_final ):
1182
+ def execute_cudnn_gemm_with_per_tensor_q_graph (graph , a , b , alpha , c_final , workspace ):
1183
1183
variant_pack = {
1184
1184
UIDs .A_UID .value : a ,
1185
1185
UIDs .B_UID .value : b ,
@@ -1190,9 +1190,10 @@ def execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, alpha, c_final):
1190
1190
stream = torch .cuda .current_stream (a .device )
1191
1191
cudnn_handle = _get_cudnn_handle (stream )
1192
1192
1193
- workspace = torch .empty (
1194
- graph .get_workspace_size (), device = a .device , dtype = torch .uint8
1195
- )
1193
+ if graph .get_workspace_size () > DEFAULT_WORKSPACE_SIZE :
1194
+ workspace = torch .empty (
1195
+ graph .get_workspace_size (), device = a .device , dtype = torch .uint8
1196
+ )
1196
1197
1197
1198
graph .execute (variant_pack , workspace , handle = cudnn_handle )
1198
1199
@@ -1216,6 +1217,7 @@ def _cudnn_gemm_fp8(
1216
1217
dq_scale : torch .Tensor ,
1217
1218
out : Optional [torch .Tensor ],
1218
1219
torch_out_dtype : torch .dtype ,
1220
+ workspace : torch .Tensor ,
1219
1221
):
1220
1222
_check_cudnn_availability ()
1221
1223
@@ -1235,7 +1237,7 @@ def _cudnn_gemm_fp8(
1235
1237
a .device ,
1236
1238
)
1237
1239
1238
- execute_cudnn_gemm_with_per_tensor_q_graph (graph , a , b , dq_scale , out )
1240
+ execute_cudnn_gemm_with_per_tensor_q_graph (graph , a , b , dq_scale , out , workspace )
1239
1241
return out
1240
1242
1241
1243
@@ -1550,12 +1552,12 @@ def bmm_fp8(
1550
1552
dtype = dtype ,
1551
1553
)
1552
1554
1555
+ workspace_buffer = _get_cache_buf (
1556
+ "bmm_fp8_workspace" , DEFAULT_WORKSPACE_SIZE , A .device
1557
+ )
1553
1558
if backend == "cudnn" :
1554
- return _cudnn_gemm_fp8 (A , B , A_scale * B_scale , out , dtype )
1559
+ return _cudnn_gemm_fp8 (A , B , A_scale * B_scale , out , dtype , workspace_buffer )
1555
1560
elif backend == "cublas" :
1556
- workspace_buffer = _get_cache_buf (
1557
- "bmm_fp8_workspace" , DEFAULT_WORKSPACE_SIZE , A .device
1558
- )
1559
1561
get_gemm_module ().bmm_fp8 (workspace_buffer , A , B , out , A_scale , B_scale )
1560
1562
return out
1561
1563
0 commit comments