Commit 5a8e005

refactor: reuse workspace for bmm_fp8
1 parent a6a1e49 commit 5a8e005

1 file changed: +11 -9 lines

flashinfer/gemm.py

Lines changed: 11 additions & 9 deletions
@@ -1179,7 +1179,7 @@ def build_cudnn_gemm_with_per_tensor_q_graph(
     return graph
 
 
-def execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, alpha, c_final):
+def execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, alpha, c_final, workspace):
     variant_pack = {
         UIDs.A_UID.value: a,
         UIDs.B_UID.value: b,
@@ -1190,9 +1190,10 @@ def execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, alpha, c_final):
     stream = torch.cuda.current_stream(a.device)
     cudnn_handle = _get_cudnn_handle(stream)
 
-    workspace = torch.empty(
-        graph.get_workspace_size(), device=a.device, dtype=torch.uint8
-    )
+    if graph.get_workspace_size() > DEFAULT_WORKSPACE_SIZE:
+        workspace = torch.empty(
+            graph.get_workspace_size(), device=a.device, dtype=torch.uint8
+        )
 
     graph.execute(variant_pack, workspace, handle=cudnn_handle)
 
@@ -1216,6 +1217,7 @@ def _cudnn_gemm_fp8(
     dq_scale: torch.Tensor,
     out: Optional[torch.Tensor],
     torch_out_dtype: torch.dtype,
+    workspace: torch.Tensor,
 ):
     _check_cudnn_availability()
 
@@ -1235,7 +1237,7 @@ def _cudnn_gemm_fp8(
         a.device,
     )
 
-    execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, dq_scale, out)
+    execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, dq_scale, out, workspace)
     return out
 
 
@@ -1550,12 +1552,12 @@ def bmm_fp8(
         dtype=dtype,
     )
 
+    workspace_buffer = _get_cache_buf(
+        "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device
+    )
     if backend == "cudnn":
-        return _cudnn_gemm_fp8(A, B, A_scale * B_scale, out, dtype)
+        return _cudnn_gemm_fp8(A, B, A_scale * B_scale, out, dtype, workspace_buffer)
     elif backend == "cublas":
-        workspace_buffer = _get_cache_buf(
-            "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device
-        )
         get_gemm_module().bmm_fp8(workspace_buffer, A, B, out, A_scale, B_scale)
         return out
 