diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py
index 665e32a5f..36ac1ed78 100755
--- a/flashinfer/gemm.py
+++ b/flashinfer/gemm.py
@@ -1104,7 +1104,7 @@ def execute_cudnn_gemm_fp4_graph(
         UIDs.O_UID.value: c_final,
     }
 
-    if graph.get_workspace_size() > DEFAULT_WORKSPACE_SIZE:
+    if workspace_buffer.numel() < graph.get_workspace_size():
         workspace_buffer = torch.empty(
             graph.get_workspace_size(), device=a.device, dtype=torch.uint8
         )
@@ -1179,7 +1179,7 @@ def build_cudnn_gemm_with_per_tensor_q_graph(
     return graph
 
 
-def execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, alpha, c_final):
+def execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, alpha, c_final, workspace):
     variant_pack = {
         UIDs.A_UID.value: a,
         UIDs.B_UID.value: b,
@@ -1190,9 +1190,10 @@ def execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, alpha, c_final):
     stream = torch.cuda.current_stream(a.device)
     cudnn_handle = _get_cudnn_handle(stream)
 
-    workspace = torch.empty(
-        graph.get_workspace_size(), device=a.device, dtype=torch.uint8
-    )
+    if workspace.numel() < graph.get_workspace_size():
+        workspace = torch.empty(
+            graph.get_workspace_size(), device=a.device, dtype=torch.uint8
+        )
 
     graph.execute(variant_pack, workspace, handle=cudnn_handle)
 
@@ -1216,14 +1217,10 @@ def _cudnn_gemm_fp8(
     dq_scale: torch.Tensor,
     out: Optional[torch.Tensor],
     torch_out_dtype: torch.dtype,
+    workspace: torch.Tensor,
 ):
     _check_cudnn_availability()
 
-    if out is None:
-        out = torch.empty(
-            a.shape[0], a.shape[1], b.shape[2], dtype=torch_out_dtype, device=a.device
-        )
-
     graph = build_cudnn_gemm_with_per_tensor_q_graph(
         a.shape,
         a.stride(),
@@ -1235,7 +1232,7 @@ def _cudnn_gemm_fp8(
         a.device,
     )
 
-    execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, dq_scale, out)
+    execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, dq_scale, out, workspace)
 
     return out
 
@@ -1550,12 +1547,12 @@ def bmm_fp8(
             dtype=dtype,
         )
 
+    workspace_buffer = _get_cache_buf(
+        "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device
+    )
     if backend == "cudnn":
-        return _cudnn_gemm_fp8(A, B, A_scale * B_scale, out, dtype)
+        return _cudnn_gemm_fp8(A, B, A_scale * B_scale, out, dtype, workspace_buffer)
     elif backend == "cublas":
-        workspace_buffer = _get_cache_buf(
-            "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device
-        )
         get_gemm_module().bmm_fp8(workspace_buffer, A, B, out, A_scale, B_scale)
         return out
 
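Sketch, not part of the patch: a minimal standalone illustration of the reuse-or-grow workspace pattern that the hunks above apply to the cudnn execution paths. DEFAULT_WORKSPACE_SIZE mirrors the constant referenced in the diff, but its concrete value here is an assumption, and ensure_workspace is a hypothetical helper name, not a function in flashinfer.

# Illustrative sketch, assuming only PyTorch. The "workspace.numel() <
# graph.get_workspace_size()" checks added in the diff follow this pattern:
# keep reusing one cached uint8 buffer and only reallocate when a graph
# needs more bytes than the buffer currently holds.
import torch

DEFAULT_WORKSPACE_SIZE = 32 * 1024 * 1024  # assumed default size in bytes


def ensure_workspace(workspace: torch.Tensor, required_bytes: int) -> torch.Tensor:
    # Reuse the caller-provided buffer when it is large enough; otherwise
    # allocate a larger one on the same device.
    if workspace.numel() < required_bytes:
        workspace = torch.empty(
            required_bytes, device=workspace.device, dtype=torch.uint8
        )
    return workspace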