Commit db1983e

refactor: reuse workspace for bmm_fp8
1 parent a6a1e49 commit db1983e

File tree: 1 file changed (+12 −15 lines)

flashinfer/gemm.py

Lines changed: 12 additions & 15 deletions
@@ -1104,7 +1104,7 @@ def execute_cudnn_gemm_fp4_graph(
         UIDs.O_UID.value: c_final,
     }

-    if graph.get_workspace_size() > DEFAULT_WORKSPACE_SIZE:
+    if workspace_buffer.numel() < graph.get_workspace_size():
         workspace_buffer = torch.empty(
             graph.get_workspace_size(), device=a.device, dtype=torch.uint8
         )
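Note: the new guard reallocates only when the buffer actually passed in is smaller than what this graph needs, instead of comparing the graph's requirement against the fixed DEFAULT_WORKSPACE_SIZE; the same grow-on-demand guard appears again in the per-tensor-quantization path below. A minimal sketch of the pattern, with a hypothetical ensure_workspace helper (the helper name is an assumption, not part of this diff):

    import torch

    def ensure_workspace(buf: torch.Tensor, required: int) -> torch.Tensor:
        # Reuse the caller's cached buffer when it already fits the request;
        # otherwise allocate a larger one for this call. Growing on demand is
        # what lets one cached buffer serve graphs of different sizes.
        if buf.numel() < required:
            buf = torch.empty(required, device=buf.device, dtype=torch.uint8)
        return buf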
@@ -1179,7 +1179,7 @@ def build_cudnn_gemm_with_per_tensor_q_graph(
     return graph


-def execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, alpha, c_final):
+def execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, alpha, c_final, workspace):
     variant_pack = {
         UIDs.A_UID.value: a,
         UIDs.B_UID.value: b,
@@ -1190,9 +1190,10 @@ def execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, alpha, c_final):
     stream = torch.cuda.current_stream(a.device)
     cudnn_handle = _get_cudnn_handle(stream)

-    workspace = torch.empty(
-        graph.get_workspace_size(), device=a.device, dtype=torch.uint8
-    )
+    if workspace.numel() < graph.get_workspace_size():
+        workspace = torch.empty(
+            graph.get_workspace_size(), device=a.device, dtype=torch.uint8
+        )

     graph.execute(variant_pack, workspace, handle=cudnn_handle)

@@ -1216,14 +1217,10 @@ def _cudnn_gemm_fp8(
     dq_scale: torch.Tensor,
     out: Optional[torch.Tensor],
     torch_out_dtype: torch.dtype,
+    workspace: torch.Tensor,
 ):
     _check_cudnn_availability()

-    if out is None:
-        out = torch.empty(
-            a.shape[0], a.shape[1], b.shape[2], dtype=torch_out_dtype, device=a.device
-        )
-
     graph = build_cudnn_gemm_with_per_tensor_q_graph(
         a.shape,
         a.stride(),
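With the "if out is None" branch gone, _cudnn_gemm_fp8 now expects its caller to supply an allocated out tensor; the dtype=dtype context in the bmm_fp8 hunk further down suggests the caller already allocates it. For reference, a standalone sketch of the shape logic the removed branch implemented (hypothetical helper name, shown only to document the expected output shape):

    import torch

    def alloc_bmm_out(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        # For a batched matmul with a of shape (batch, m, k) and b of shape
        # (batch, k, n), the output is (batch, m, n) on the inputs' device.
        return torch.empty(a.shape[0], a.shape[1], b.shape[2], dtype=dtype, device=a.device)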
@@ -1235,7 +1232,7 @@ def _cudnn_gemm_fp8(
         a.device,
     )

-    execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, dq_scale, out)
+    execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, dq_scale, out, workspace)
     return out

@@ -1550,12 +1547,12 @@ def bmm_fp8(
             dtype=dtype,
         )

+    workspace_buffer = _get_cache_buf(
+        "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device
+    )
     if backend == "cudnn":
-        return _cudnn_gemm_fp8(A, B, A_scale * B_scale, out, dtype)
+        return _cudnn_gemm_fp8(A, B, A_scale * B_scale, out, dtype, workspace_buffer)
     elif backend == "cublas":
-        workspace_buffer = _get_cache_buf(
-            "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device
-        )
         get_gemm_module().bmm_fp8(workspace_buffer, A, B, out, A_scale, B_scale)
         return out

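Hoisting _get_cache_buf above the backend dispatch means both the cudnn and cublas paths now share one cached workspace, fetched once per call. A sketch of how such a per-name, per-device buffer cache can behave (an assumption for illustration, not flashinfer's actual _get_cache_buf implementation):

    import torch

    _cache_bufs: dict = {}

    def get_cache_buf(name: str, nbytes: int, device: torch.device) -> torch.Tensor:
        # Allocate one uint8 buffer per (name, device) pair and return the
        # same tensor on every later call, so repeated bmm_fp8 invocations
        # stop paying for a fresh workspace allocation each time.
        key = (name, device)
        buf = _cache_bufs.get(key)
        if buf is None:
            buf = torch.empty(nbytes, dtype=torch.uint8, device=device)
            _cache_bufs[key] = buf
        return buf

Combined with the numel() guards above, the cuDNN path reuses this cached buffer; note that when a graph needs more than DEFAULT_WORKSPACE_SIZE bytes, the grown buffer is local to that call and is not written back to the cache, so oversized graphs still reallocate on every call.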