27 changes: 12 additions & 15 deletions flashinfer/gemm.py
@@ -1104,7 +1104,7 @@ def execute_cudnn_gemm_fp4_graph(
         UIDs.O_UID.value: c_final,
     }
 
-    if graph.get_workspace_size() > DEFAULT_WORKSPACE_SIZE:
+    if workspace_buffer.numel() < graph.get_workspace_size():
         workspace_buffer = torch.empty(
             graph.get_workspace_size(), device=a.device, dtype=torch.uint8
         )
@@ -1179,7 +1179,7 @@ def build_cudnn_gemm_with_per_tensor_q_graph(
     return graph
 
 
-def execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, alpha, c_final):
+def execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, alpha, c_final, workspace):
     variant_pack = {
         UIDs.A_UID.value: a,
         UIDs.B_UID.value: b,
@@ -1190,9 +1190,10 @@ def execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, alpha, c_final):
     stream = torch.cuda.current_stream(a.device)
     cudnn_handle = _get_cudnn_handle(stream)
 
-    workspace = torch.empty(
-        graph.get_workspace_size(), device=a.device, dtype=torch.uint8
-    )
+    if workspace.numel() < graph.get_workspace_size():
+        workspace = torch.empty(
+            graph.get_workspace_size(), device=a.device, dtype=torch.uint8
+        )
 
     graph.execute(variant_pack, workspace, handle=cudnn_handle)
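For readers skimming the diff: the net effect of the two hunks above is that the executor now reuses a caller-owned workspace and only allocates a fresh buffer when cuDNN reports a larger requirement than the buffer holds. Below is a minimal, self-contained sketch of that pattern; FakeGraph and execute_with_workspace are illustrative stand-ins, not flashinfer or cuDNN API.

import torch

class FakeGraph:
    """Toy stand-in for a built cuDNN graph; only the workspace query matters here."""

    def __init__(self, required_bytes: int):
        self._required = required_bytes

    def get_workspace_size(self) -> int:
        return self._required

    def execute(self, workspace: torch.Tensor) -> None:
        # The real graph launches kernels; here we only check the buffer is big enough.
        assert workspace.numel() >= self._required


def execute_with_workspace(graph: FakeGraph, workspace: torch.Tensor) -> torch.Tensor:
    # Reallocate only when the cached buffer is smaller than this graph needs,
    # mirroring the numel() check introduced in this PR.
    if workspace.numel() < graph.get_workspace_size():
        workspace = torch.empty(
            graph.get_workspace_size(), device=workspace.device, dtype=torch.uint8
        )
    graph.execute(workspace)
    # Returning the buffer is one way for a caller to keep a grown workspace;
    # in the diff above the grown buffer stays local to the call.
    return workspace


# One buffer survives repeated calls; only the 2048-byte graph forces a reallocation.
buf = torch.empty(1024, dtype=torch.uint8)
for size in (512, 2048, 1024):
    buf = execute_with_workspace(FakeGraph(size), buf)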

@@ -1216,14 +1217,10 @@ def _cudnn_gemm_fp8(
     dq_scale: torch.Tensor,
     out: Optional[torch.Tensor],
     torch_out_dtype: torch.dtype,
+    workspace: torch.Tensor,
 ):
     _check_cudnn_availability()
 
-    if out is None:
-        out = torch.empty(
-            a.shape[0], a.shape[1], b.shape[2], dtype=torch_out_dtype, device=a.device
-        )
-
     graph = build_cudnn_gemm_with_per_tensor_q_graph(
         a.shape,
         a.stride(),
@@ -1235,7 +1232,7 @@ def _cudnn_gemm_fp8(
         a.device,
     )
 
-    execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, dq_scale, out)
+    execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, dq_scale, out, workspace)
     return out


@@ -1550,12 +1547,12 @@ def bmm_fp8(
         dtype=dtype,
     )
 
+    workspace_buffer = _get_cache_buf(
+        "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device
+    )
     if backend == "cudnn":
-        return _cudnn_gemm_fp8(A, B, A_scale * B_scale, out, dtype)
+        return _cudnn_gemm_fp8(A, B, A_scale * B_scale, out, dtype, workspace_buffer)
     elif backend == "cublas":
-        workspace_buffer = _get_cache_buf(
-            "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device
-        )
         get_gemm_module().bmm_fp8(workspace_buffer, A, B, out, A_scale, B_scale)
         return out

Review thread on the hoisted workspace_buffer = _get_cache_buf(...) call:

Collaborator:
Why not separate the caches of the different backends and allocate the cuDNN workspace inside execute_cudnn_gemm_with_per_tensor_q_graph after graph.get_workspace_size() is calculated?

Collaborator (Author):
I noticed that all the methods in gemm.py use DEFAULT_WORKSPACE_SIZE. I assumed the design choice was that different tactics might require different workspace sizes, and a sufficiently large default helps avoid frequent reallocations when the requirement changes. If that's incorrect, I can update all the methods to request a workspace size matching the current tactic's requirement instead. Thanks!
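For illustration, the cached default-sized workspace the author describes — one buffer per workspace name and device, reused across calls — boils down to the sketch below. get_cache_buf, _WORKSPACE_CACHE, and the size constant are assumptions made for this example, not the actual flashinfer internals.

import torch

# Assumed default size, for illustration only.
DEFAULT_WORKSPACE_SIZE = 32 * 1024 * 1024

# One cached buffer per (name, device); repeated calls reuse it instead of reallocating.
_WORKSPACE_CACHE: dict = {}


def get_cache_buf(name: str, size: int, device) -> torch.Tensor:
    key = (name, torch.device(device))
    buf = _WORKSPACE_CACHE.get(key)
    if buf is None or buf.numel() < size:
        buf = torch.empty(size, device=device, dtype=torch.uint8)
        _WORKSPACE_CACHE[key] = buf
    return buf


# e.g. every bmm_fp8 call on the same device would share this one buffer:
# ws = get_cache_buf("bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device)

Against that backdrop, the reviewer's suggestion amounts to sizing the buffer per graph via get_workspace_size(); the PR keeps the shared default and only falls back to a per-call allocation when that default is too small.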
