Skip to content

Commit 4276a8e

Browse files
liyuhannnnn authored and codego7250 committed
[TRTLLM-6222][feat] Several perf opt for cuteDSL nvf4 gemm (NVIDIA#9428)
Signed-off-by: Yuhan Li <51736452+liyuhannnnn@users.noreply.github.com>
1 parent 627601f commit 4276a8e

File tree

4 files changed

+1150
-501
lines changed

4 files changed

+1150
-501
lines changed

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 43 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,14 @@ def get_valid_tactics(
8080

8181
# full shamoo
8282
mma_tiler_mn_candidates = [
83-
(256, 128),
83+
(128, 64),
84+
(256, 64),
8485
(128, 128),
86+
(256, 128),
87+
(128, 192),
88+
(256, 192),
8589
(128, 256),
8690
(256, 256),
87-
(256, 64),
88-
(128, 64),
8991
]
9092
cluster_shape_mn_candidates = [
9193
(1, 1),
@@ -99,37 +101,38 @@ def get_valid_tactics(
99101
(4, 4),
100102
]
101103
swap_ab_candidates = [True, False]
104+
use_prefetch_candidates = [True, False]
102105

103106
valid_tactics = []
104-
for swap_ab in swap_ab_candidates:
105-
for mma_tiler_mn in mma_tiler_mn_candidates:
106-
for cluster_shape_mn in cluster_shape_mn_candidates:
107-
if swap_ab:
108-
c_major = "m"
109-
kernel_m = n
110-
kernel_n = m
111-
else:
112-
c_major = "n"
113-
kernel_m = m
114-
kernel_n = n
115-
116-
if self.__class__.kernel_class.can_implement(
117-
cutlass.Float4E2M1FN, # ab_dtype,
118-
cutlass.Float8E4M3FN, # sf_dtype
119-
sf_vec_size, # sf_vec_size,
120-
cutlass.BFloat16, # c_dtype,
121-
mma_tiler_mn,
122-
cluster_shape_mn,
123-
kernel_m,
124-
kernel_n,
125-
real_k,
126-
batch_size,
127-
a_major,
128-
b_major,
129-
c_major,
130-
):
131-
valid_tactics.append(
132-
(mma_tiler_mn, cluster_shape_mn, swap_ab))
107+
for mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch in itertools.product(
108+
mma_tiler_mn_candidates, cluster_shape_mn_candidates,
109+
swap_ab_candidates, use_prefetch_candidates):
110+
if swap_ab:
111+
c_major = "m"
112+
kernel_m = n
113+
kernel_n = m
114+
else:
115+
c_major = "n"
116+
kernel_m = m
117+
kernel_n = n
118+
119+
if self.__class__.kernel_class.can_implement(
120+
cutlass.Float4E2M1FN, # ab_dtype,
121+
cutlass.Float8E4M3FN, # sf_dtype
122+
sf_vec_size, # sf_vec_size,
123+
cutlass.BFloat16, # c_dtype,
124+
mma_tiler_mn,
125+
cluster_shape_mn,
126+
kernel_m,
127+
kernel_n,
128+
real_k,
129+
batch_size,
130+
a_major,
131+
b_major,
132+
c_major,
133+
):
134+
valid_tactics.append(
135+
(mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch))
133136

134137
return valid_tactics
135138

@@ -158,21 +161,22 @@ def forward(
158161
inputs[3]: Weight scale tensor of shape (n, k//16), dtype: fp8.
159162
inputs[4]: Alpha scaling factor. dtype: float32.
160163
inputs[5]: Output dtype, expected to be torch.bfloat16.
161-
tactic: Tiling and cluster strategy, typically a tuple (mma_tiler_mn, cluster_shape_mn).
164+
tactic: Tiling and cluster strategy, typically a tuple (mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch).
162165
163166
Returns:
164167
torch.Tensor: Output tensor of shape (m, n), dtype: bf16.
165168
"""
166169
sf_vec_size = 16
167170

168171
if isinstance(tactic, tuple):
169-
mma_tiler_mn, cluster_shape_mn, swap_ab = tactic
172+
mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch = tactic
170173
else:
171174
# fallback to default tactic
172-
mma_tiler_mn, cluster_shape_mn, swap_ab = [
175+
mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch = [
173176
(128, 128),
174177
(1, 1),
175178
False,
179+
False,
176180
]
177181

178182
a_tensor, b_tensor, a_sf_tensor, b_sf_tensor = inputs
@@ -208,7 +212,8 @@ def forward(
208212
torch_stream = torch.cuda.current_stream()
209213
stream = cuda.CUstream(torch_stream.cuda_stream)
210214

211-
cache_key = (sf_vec_size, mma_tiler_mn, cluster_shape_mn, swap_ab)
215+
cache_key = (sf_vec_size, mma_tiler_mn, cluster_shape_mn, swap_ab,
216+
use_prefetch)
212217
if swap_ab:
213218
kernel_a_ptr = b_ptr
214219
kernel_a_sf_ptr = b_sf_ptr
@@ -233,6 +238,7 @@ def forward(
233238
sf_vec_size,
234239
mma_tiler_mn,
235240
cluster_shape_mn,
241+
use_prefetch,
236242
)
237243
# Compute max active clusters on current device
238244
hardware_info = cutlass.utils.HardwareInfo()
@@ -257,6 +263,7 @@ def forward(
257263
max_active_clusters,
258264
stream,
259265
swap_ab,
266+
options=f"--opt-level 2",
260267
)
261268

262269
self.__class__.kernel_cache[cache_key] = compiled_gemm

0 commit comments

Comments (0)