Commit 63a3074

bugfix: ensure graph is captured and executed on the same stream to avoid rep… (#1303)
## 📌 Description

Fix a CUDA graph replay issue when integrating the mm_fp4 API: ensure the graph is captured and executed on the same stream.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes
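For background on why same-stream capture matters: CUDA graph replay re-enqueues exactly the work recorded during capture, and capture is bound to one stream, so work issued on any other stream is either missed by the graph or aborts the capture. Below is a minimal sketch of the capture/replay pattern involved, using plain PyTorch CUDA graph APIs with illustrative shapes; it is not FlashInfer code.

```python
import torch

device = torch.device("cuda:0")
static_a = torch.randn(64, 64, device=device)
static_b = torch.randn(64, 64, device=device)

# Warm up on a side stream first, as the PyTorch CUDA-graphs docs recommend.
s = torch.cuda.Stream(device=device)
s.wait_stream(torch.cuda.current_stream(device))
with torch.cuda.stream(s):
    static_out = static_a @ static_b
torch.cuda.current_stream(device).wait_stream(s)

# Capture: torch.cuda.graph() makes its capture stream current for the
# duration of the block, so all recorded work lands on a single stream.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = static_a @ static_b

# Replay: refresh the static input buffers in place, then re-enqueue.
static_a.copy_(torch.randn(64, 64, device=device))
g.replay()
torch.cuda.synchronize(device)
```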
1 parent 04b9a2a commit 63a3074

1 file changed: +15 −10 lines

flashinfer/gemm.py

File mode changed: 100644 → 100755
Lines changed: 15 additions & 10 deletions
```diff
@@ -822,10 +822,11 @@ def build_cudnn_gemm_block_scale_dequantize_graph(
     scale_type,
     o_type,
     block_size,
+    device,
 ):
     _check_cudnn_availability()
-
-    with cudnn.graph(_get_cudnn_handle(torch.cuda.current_stream())) as (graph, _):
+    stream = torch.cuda.current_stream(device)
+    with cudnn.graph(_get_cudnn_handle(stream)) as (graph, _):
         a_cudnn_tensor = graph.tensor(
             name="a", dim=a_shape, stride=a_stride, data_type=ab_type
         )
```
```diff
@@ -911,17 +912,17 @@ def execute_cudnn_gemm_fp4_graph(graph, a, b, a_descale, b_descale, alpha, c_final):
     }

     workspace = torch.empty(
-        graph.get_workspace_size(), device="cuda", dtype=torch.uint8
+        graph.get_workspace_size(), device=a.device, dtype=torch.uint8
     )

-    graph.execute(
-        variant_pack, workspace, handle=_get_cudnn_handle(torch.cuda.current_stream())
-    )
+    stream = torch.cuda.current_stream(a.device)
+
+    graph.execute(variant_pack, workspace, handle=_get_cudnn_handle(stream))


 @functools.lru_cache(maxsize=128)
 def build_cudnn_gemm_with_per_tensor_q_graph(
-    a_shape, a_stride, b_shape, b_stride, a_type, b_type, o_type
+    a_shape, a_stride, b_shape, b_stride, a_type, b_type, o_type, device
 ):
     """Build a cuDNN graph for GEMM with per-tensor quantization.
```
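One detail the hunk above surfaces: `build_cudnn_gemm_with_per_tensor_q_graph` is memoized with `functools.lru_cache`, so threading `device` through the signature also keys the cached graph by device. A small sketch of that effect; the builder below is a hypothetical stand-in, not the FlashInfer implementation:

```python
import functools

@functools.lru_cache(maxsize=128)
def build_graph(shape, device):
    # lru_cache keys on all arguments, so each (shape, device) pair gets
    # its own entry; a graph built against cuda:0's stream is never
    # handed back for tensors living on cuda:1.
    print(f"building graph for {shape} on {device}")
    return object()  # placeholder for a real cuDNN graph handle

build_graph((64, 64), "cuda:0")  # builds
build_graph((64, 64), "cuda:0")  # cache hit, nothing printed
build_graph((64, 64), "cuda:1")  # builds again for the other device
```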
```diff
@@ -941,7 +942,8 @@ def build_cudnn_gemm_with_per_tensor_q_graph(
     """
     _check_cudnn_availability()

-    with cudnn.graph(_get_cudnn_handle(torch.cuda.current_stream())) as (graph, _):
+    stream = torch.cuda.current_stream(device)
+    with cudnn.graph(_get_cudnn_handle(stream)) as (graph, _):

         a_cudnn_tensor = graph.tensor(
             name="a", dim=a_shape, stride=a_stride, data_type=a_type
```
```diff
@@ -992,10 +994,11 @@ def execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, alpha, c_final):
         UIDs.O_UID.value: c_final,
     }

-    cudnn_handle = _get_cudnn_handle(torch.cuda.current_stream())
+    stream = torch.cuda.current_stream(a.device)
+    cudnn_handle = _get_cudnn_handle(stream)

     workspace = torch.empty(
-        graph.get_workspace_size(), device="cuda", dtype=torch.uint8
+        graph.get_workspace_size(), device=a.device, dtype=torch.uint8
     )

     graph.execute(variant_pack, workspace, handle=cudnn_handle)
```
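The `device="cuda"` → `device=a.device` change in the workspace allocation matters on multi-GPU hosts: the bare string `"cuda"` resolves to the current CUDA device, which need not be the device holding `a`. A short illustration, assuming a machine with at least two visible GPUs:

```python
import torch

a = torch.randn(16, device="cuda:1")

# "cuda" resolves to the current device (cuda:0 unless changed), so this
# workspace can land on a different GPU than the operands:
ws_default = torch.empty(1024, device="cuda", dtype=torch.uint8)

# Allocating on a.device keeps the workspace, the operands, and the
# stream from torch.cuda.current_stream(a.device) on the same GPU:
ws_matched = torch.empty(1024, device=a.device, dtype=torch.uint8)

print(ws_default.device, ws_matched.device)  # cuda:0 cuda:1
```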
```diff
@@ -1036,6 +1039,7 @@ def _cudnn_gemm_fp8(
         _torch_data_type_to_cudnn_data_type(a.dtype),
         _torch_data_type_to_cudnn_data_type(b.dtype),
         _torch_data_type_to_cudnn_data_type(torch_out_dtype),
+        a.device,
     )

     execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, dq_scale, out)
```
```diff
@@ -1223,6 +1227,7 @@ def mm_fp4(
         torch.float8_e4m3fn,
         _torch_data_type_to_cudnn_data_type(out_dtype),
         block_size,
+        a.device,
     )

     # execute the fp4 cudnn graph
```
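The common thread in all of these hunks is that graph construction and execution now both resolve the current stream of the input's device, so when a caller captures `mm_fp4` inside `torch.cuda.graph(...)`, the cuDNN handle follows the capture stream rather than a default one. A sketch of the behavior this relies on, assuming PyTorch's documented semantics that `torch.cuda.graph` makes an internal capture stream current inside the block:

```python
import torch

device = torch.device("cuda:0")
x = torch.zeros(1, device=device)

before = torch.cuda.current_stream(device)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Inside capture the capture stream is current, so any handle keyed
    # off current_stream() here is bound to the stream being recorded.
    during = torch.cuda.current_stream(device)
    x += 1  # a trivial captured op so the graph is not empty

print(before == during)  # False: capture runs on its own stream
```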
