Commit e0c7019
bugfix: fix perf issue by using fp8 graph that can use cublaslt (#1435)
Replace the gemm+pointwise(alpha) pattern with gemm+pointwise(scale_a)+pointwise(scale_b).

## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

Signed-off-by: Vincent Huang <[email protected]>
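The rewrite relies on a simple identity: with per-tensor FP8 quantization both dequantization scales are scalars, so the combined alpha can be applied as two successive pointwise multiplies instead of one precomputed product. A minimal sketch in plain PyTorch (illustration only, not FlashInfer code):

```python
import torch

a = torch.randn(2, 16, 32)
b = torch.randn(2, 32, 8)
a_scale = torch.tensor(0.5)   # per-tensor dequantization scale for A
b_scale = torch.tensor(0.25)  # per-tensor dequantization scale for B

# Old graph: gemm + pointwise(alpha), with alpha = a_scale * b_scale
# precomputed (the A_scale * B_scale in the old bmm_fp8 call below).
ref = (a @ b) * (a_scale * b_scale)

# New graph: gemm + pointwise(scale_a) + pointwise(scale_b). Numerically
# identical, but this graph shape is one cuDNN can lower to a
# cuBLASLt-backed engine, which is the perf fix.
out = ((a @ b) * a_scale) * b_scale

torch.testing.assert_close(ref, out)
```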
1 parent 12dfcc5 commit e0c7019

File tree

1 file changed: +47 -26 lines
flashinfer/gemm.py (47 additions, 26 deletions)
@@ -920,7 +920,9 @@ class UIDs(Enum):
     ALPHA_UID = 2
     BLOCK_DESCALE_A_UID = 3
     BLOCK_DESCALE_B_UID = 4
-    O_UID = 5
+    A_SCALE_UID = 5
+    B_SCALE_UID = 6
+    O_UID = 7


 def _check_cudnn_availability():
@@ -1118,7 +1120,7 @@ def execute_cudnn_gemm_fp4_graph(
         UIDs.O_UID.value: c_final,
     }

-    if graph.get_workspace_size() > DEFAULT_WORKSPACE_SIZE:
+    if workspace_buffer.numel() < graph.get_workspace_size():
         workspace_buffer = torch.empty(
             graph.get_workspace_size(), device=a.device, dtype=torch.uint8
         )
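The corrected guard above compares the graph's workspace requirement against the buffer the caller actually passed in, rather than against the `DEFAULT_WORKSPACE_SIZE` constant. A self-contained sketch of the pattern (`ensure_workspace` is a hypothetical helper, not part of this diff):

```python
import torch

def ensure_workspace(workspace_buffer: torch.Tensor, required_bytes: int) -> torch.Tensor:
    # Grow the cached workspace only when it is smaller than what the
    # compiled graph reports it needs; otherwise reuse it as-is.
    if workspace_buffer.numel() < required_bytes:
        workspace_buffer = torch.empty(
            required_bytes, device=workspace_buffer.device, dtype=torch.uint8
        )
    return workspace_buffer
```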
@@ -1158,8 +1160,14 @@ def build_cudnn_gemm_with_per_tensor_q_graph(
     b_cudnn_tensor = graph.tensor(
         name="b", dim=b_shape, stride=b_stride, data_type=b_type
     )
-    scale_cudnn_tensor = graph.tensor(
-        name="scale",
+    a_scale_cudnn_tensor = graph.tensor(
+        name="a_scale",
+        dim=(1, 1, 1),
+        stride=(1, 1, 1),
+        data_type=cudnn.data_type.FLOAT,
+    )
+    b_scale_cudnn_tensor = graph.tensor(
+        name="b_scale",
         dim=(1, 1, 1),
         stride=(1, 1, 1),
         data_type=cudnn.data_type.FLOAT,
@@ -1171,18 +1179,28 @@ def build_cudnn_gemm_with_per_tensor_q_graph(
         compute_data_type=cudnn.data_type.FLOAT,
     )
     c_cudnn_tensor.set_name("c").set_data_type(cudnn.data_type.FLOAT)
-    c_final_cudnn_tensor = graph.mul(
-        name="scale_mul",
+    c_after_scale_a_cudnn_tensor = graph.mul(
+        name="scale_mul_a",
         a=c_cudnn_tensor,
-        b=scale_cudnn_tensor,
+        b=a_scale_cudnn_tensor,
         compute_data_type=cudnn.data_type.FLOAT,
     )
-    c_final_cudnn_tensor.set_name("c_final").set_output(True).set_data_type(o_type)
+    c_after_scale_b_cudnn_tensor = graph.mul(
+        name="scale_mul_b",
+        a=c_after_scale_a_cudnn_tensor,
+        b=b_scale_cudnn_tensor,
+        compute_data_type=cudnn.data_type.FLOAT,
+    )
+
+    c_after_scale_b_cudnn_tensor.set_name("c_final").set_output(True).set_data_type(
+        o_type
+    )

     a_cudnn_tensor.set_uid(UIDs.A_UID.value)
     b_cudnn_tensor.set_uid(UIDs.B_UID.value)
-    scale_cudnn_tensor.set_uid(UIDs.ALPHA_UID.value)
-    c_final_cudnn_tensor.set_uid(UIDs.O_UID.value)
+    a_scale_cudnn_tensor.set_uid(UIDs.A_SCALE_UID.value)
+    b_scale_cudnn_tensor.set_uid(UIDs.B_SCALE_UID.value)
+    c_after_scale_b_cudnn_tensor.set_uid(UIDs.O_UID.value)

     graph.validate()
     graph.build_operation_graph()
@@ -1193,20 +1211,24 @@ def build_cudnn_gemm_with_per_tensor_q_graph(
     return graph


-def execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, alpha, c_final):
+def execute_cudnn_gemm_with_per_tensor_q_graph(
+    graph, a, b, a_scale, b_scale, c_final, workspace
+):
     variant_pack = {
         UIDs.A_UID.value: a,
         UIDs.B_UID.value: b,
-        UIDs.ALPHA_UID.value: alpha,
+        UIDs.A_SCALE_UID.value: a_scale,
+        UIDs.B_SCALE_UID.value: b_scale,
         UIDs.O_UID.value: c_final,
     }

     stream = torch.cuda.current_stream(a.device)
     cudnn_handle = _get_cudnn_handle(stream)

-    workspace = torch.empty(
-        graph.get_workspace_size(), device=a.device, dtype=torch.uint8
-    )
+    if workspace.numel() < graph.get_workspace_size():
+        workspace = torch.empty(
+            graph.get_workspace_size(), device=a.device, dtype=torch.uint8
+        )

     graph.execute(variant_pack, workspace, handle=cudnn_handle)
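With this change the execute path stops allocating a fresh workspace on every call and instead reuses a buffer owned by the caller. A rough illustration of the caching this enables (the internals of FlashInfer's `_get_cache_buf` are an assumption here; only its call sites appear in this diff):

```python
import torch

_buf_cache: dict = {}

def get_cache_buf(name: str, num_bytes: int, device: torch.device) -> torch.Tensor:
    # Hypothetical stand-in for _get_cache_buf: keep one persistent byte
    # buffer per (name, device) and reuse it across calls, so repeated
    # bmm_fp8 invocations avoid a per-call workspace allocation.
    key = (name, str(device))
    buf = _buf_cache.get(key)
    if buf is None or buf.numel() < num_bytes:
        buf = torch.empty(num_bytes, device=device, dtype=torch.uint8)
        _buf_cache[key] = buf
    return buf
```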

@@ -1225,19 +1247,16 @@ def _torch_data_type_to_cudnn_data_type(dtype: torch.dtype):


 def _cudnn_gemm_fp8(
+    workspace: torch.Tensor,
     a: torch.Tensor,
     b: torch.Tensor,
-    dq_scale: torch.Tensor,
+    a_scale: torch.Tensor,
+    b_scale: torch.Tensor,
     out: Optional[torch.Tensor],
     torch_out_dtype: torch.dtype,
 ):
     _check_cudnn_availability()

-    if out is None:
-        out = torch.empty(
-            a.shape[0], a.shape[1], b.shape[2], dtype=torch_out_dtype, device=a.device
-        )
-
     graph = build_cudnn_gemm_with_per_tensor_q_graph(
         a.shape,
         a.stride(),
@@ -1249,7 +1268,9 @@ def _cudnn_gemm_fp8(
         a.device,
     )

-    execute_cudnn_gemm_with_per_tensor_q_graph(graph, a, b, dq_scale, out)
+    execute_cudnn_gemm_with_per_tensor_q_graph(
+        graph, a, b, a_scale, b_scale, out, workspace
+    )
     return out

@@ -1564,12 +1585,12 @@ def bmm_fp8(
         dtype=dtype,
     )

+    workspace_buffer = _get_cache_buf(
+        "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device
+    )
     if backend == "cudnn":
-        return _cudnn_gemm_fp8(A, B, A_scale * B_scale, out, dtype)
+        return _cudnn_gemm_fp8(workspace_buffer, A, B, A_scale, B_scale, out, dtype)
     elif backend == "cublas":
-        workspace_buffer = _get_cache_buf(
-            "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device
-        )
         get_gemm_module().bmm_fp8(workspace_buffer, A, B, out, A_scale, B_scale)
     return out
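For context, a usage sketch of the public entry point touched above. The `backend` values come straight from the branches in the diff; the tensor shapes, the column-major layout for `B`, and the positional argument order are assumptions, so treat this as illustrative rather than an API reference:

```python
import torch
import flashinfer

# Batched FP8 GEMM: A is (b, m, k) row-major, B is (b, k, n) column-major.
A = torch.randn(16, 48, 64, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
B = (
    torch.randn(16, 80, 64, device="cuda", dtype=torch.bfloat16)
    .to(torch.float8_e4m3fn)
    .transpose(-2, -1)
)
A_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
B_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)

# After this commit the cudnn path receives A_scale and B_scale separately
# instead of a precomputed A_scale * B_scale product.
out = flashinfer.bmm_fp8(A, B, A_scale, B_scale, torch.bfloat16, backend="cudnn")
print(out.shape)  # torch.Size([16, 48, 80])
```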