
Commit 5f0e3b4

Merge branch 'ROCm:main' into main
2 parents: b645bcb + 9b542ad

File tree: 76 files changed, +2213 −904 lines


.github/workflows/aiter-test.yaml
Lines changed: 2 additions & 1 deletion

```diff
@@ -12,7 +12,8 @@ concurrency:
   cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
 
 env:
-  DOCKER_IMAGE: "rocm/pytorch:latest"
+  # TODO: Revert to rocm/pytorch:latest once CK adds ROCm 7.2 support
+  DOCKER_IMAGE: "rocm/pytorch:latest@sha256:683765a52c61341e1674fe730ab3be861a444a45a36c0a8caae7653a08a0e208"
 
 jobs:
   check-signal:
```

3rdparty/composable_kernel

Submodule composable_kernel updated 218 files

aiter/fused_moe.py
Lines changed: 7 additions & 5 deletions

```diff
@@ -1643,11 +1643,13 @@ def fused_topk(
         M, topk, dtype=dtypes.i32, device=hidden_states.device
     )
 
-    if (
-        get_gfx() == "gfx942"
-        and (expert, topk) in [(128, 6), (128, 8), (256, 6), (256, 8)]
-        and gating_output.dtype == dtypes.fp32
-    ):
+    if (expert, topk) in [
+        (128, 4),
+        (128, 6),
+        (128, 8),
+        (256, 6),
+        (256, 8),
+    ] and gating_output.dtype in [dtypes.bf16, dtypes.fp32]:
         if topk_weights is None:
             topk_weights = torch.empty(
                 (M + 3) // 4 * 4, topk, dtype=dtypes.fp32, device=hidden_states.device
```
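The new condition drops the gfx942-only gate, adds the (128, 4) expert/top-k shape, and accepts bf16 gating logits alongside fp32. A minimal sketch of the widened dispatch predicate (the helper below is illustrative, not aiter code):

```python
import torch

# Illustrative stand-ins for aiter.dtypes; only the predicate logic matters.
fp32, bf16 = torch.float32, torch.bfloat16

def takes_fused_topk_path(expert: int, topk: int, gating_dtype: torch.dtype) -> bool:
    """Mirrors the post-commit condition in fused_topk."""
    return (expert, topk) in [
        (128, 4), (128, 6), (128, 8), (256, 6), (256, 8),
    ] and gating_dtype in [bf16, fp32]

assert takes_fused_topk_path(128, 4, bf16)     # newly covered shape and dtype
assert not takes_fused_topk_path(64, 4, fp32)  # untuned shape: generic path
```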

aiter/fused_moe_dp_shared_expert.py
Lines changed: 2 additions & 1 deletion

```diff
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: MIT
-# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
 
 import torch
 import os
@@ -500,6 +500,7 @@ def FinalFunc():
             kernelName=kernelName1,
             activation=activation,
             quant_type=q_type,
+            dst_type=dtype,
         ),
         functools.partial(
             aiter.ck_moe_stage2_fwd,
```

aiter/jit/optCompilerConfig.json
Lines changed: 17 additions & 1 deletion

```diff
@@ -443,7 +443,8 @@
         "verbose": "False",
         "blob_gen_cmd": [
             "f'{AITER_META_DIR}/hsa/codegen.py -m fmoe_2stages --output_dir {{}}'",
-            "f'{AITER_META_DIR}/hsa/codegen.py -m fmoe --output_dir {{}}'"
+            "f'{AITER_META_DIR}/hsa/codegen.py -m fmoe --output_dir {{}}'",
+            "f'{AITER_META_DIR}/hsa/codegen.py -m topksoftmax --output_dir {{}}'"
         ]
     },
     "module_moe_ck2stages": {
@@ -563,6 +564,21 @@
         "verbose": "False",
         "blob_gen_cmd": "f'{CK_DIR}/example/ck_tile/10_rmsnorm2d/generate.py --api fwd --gen_blobs --working_path {{}}'"
     },
+    "module_rmsnorm_quant": {
+        "srcs": [
+            "f'{AITER_CSRC_DIR}/kernels/rmsnorm_quant_kernels.cu'",
+            "f'{AITER_CSRC_DIR}/pybind/rmsnorm_quant_pybind.cu'"
+        ],
+        "flags_extra_cc": [],
+        "flags_extra_hip": ["'-ffast-math'"],
+        "extra_ldflags": "None",
+        "extra_include": [
+            "f'{AITER_CSRC_DIR}/include/ck_tile'",
+            "f'{AITER_CSRC_DIR}/include/opus'"
+        ],
+        "verbose": "False",
+        "blob_gen_cmd": "''"
+    },
     "module_smoothquant": {
         "srcs": [
             "f'{AITER_CSRC_DIR}/py_itfs_ck/smoothquant_kernels.cu'",
```

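The config values are stored as quoted f-string literals, which suggests the JIT loader evaluates them with path variables such as AITER_CSRC_DIR in scope when a module is built. A rough sketch of that expansion under an assumed eval-based mechanism (aiter's actual loader may differ):

```python
# Hypothetical expansion of one "srcs" entry from optCompilerConfig.json.
AITER_CSRC_DIR = "/opt/aiter/csrc"  # placeholder path, not the real install location

entry = "f'{AITER_CSRC_DIR}/kernels/rmsnorm_quant_kernels.cu'"
src_path = eval(entry)  # the stored string is itself an f-string literal
assert src_path == "/opt/aiter/csrc/kernels/rmsnorm_quant_kernels.cu"
# Note: the doubled braces in blob_gen_cmd ("{{}}") survive as a literal "{}"
# placeholder after this expansion.
```
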
aiter/ops/moe_op.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -553,7 +553,7 @@ def ck_moe_stage1_fwd(
         activation.value,
         int(splitk) if splitk is not None else splitk,
         use_non_temporal_load,
-        dtype2str_dict[dst_type],
+        None if dst_type is None else dtype2str_dict[dst_type],
     )
     return out
```
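This pairs with the dst_type=dtype plumbing in fused_moe_dp_shared_expert.py above: dst_type may now be None, and indexing dtype2str_dict with None would otherwise raise a KeyError. A tiny sketch of the guard (the mapping below is a placeholder, not aiter's actual dict):

```python
import torch

# Placeholder in the spirit of aiter's dtype2str_dict.
dtype2str_dict = {torch.float16: "fp16", torch.bfloat16: "bf16"}

def encode_dst_type(dst_type):
    # Old behaviour: dtype2str_dict[dst_type] raised KeyError for None.
    # New behaviour: None passes through so the kernel side can pick a default.
    return None if dst_type is None else dtype2str_dict[dst_type]

assert encode_dst_type(None) is None
assert encode_dst_type(torch.bfloat16) == "bf16"
```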

aiter/ops/rmsnorm.py
Lines changed: 152 additions & 8 deletions

```diff
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: MIT
-# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
 
 import torch
 from torch import Tensor
@@ -59,16 +59,20 @@ def rms_norm(
     ...
 
 
-@compile_ops("module_rmsnorm", gen_fake=gen_rms_norm_fake_tensor)
 def rmsnorm2d_fwd(
     input: torch.Tensor,
     weight: torch.Tensor,
     epsilon: float,
     use_model_sensitive_rmsnorm: int = 0,
-) -> Tensor: ...
+) -> Tensor:
+    out = torch.empty_like(input, dtype=input.dtype, device=input.device)
+    if use_model_sensitive_rmsnorm > 0 or input.shape[-1] > 8192:
+        rmsnorm2d_fwd_ck(out, input, weight, epsilon, use_model_sensitive_rmsnorm)
+    else:
+        rmsnorm(out, input, weight, epsilon)
+    return out
 
 
-@compile_ops("module_rmsnorm")
 def rmsnorm2d_fwd_with_add(
     out: Tensor,
     input: Tensor,
@@ -77,7 +81,19 @@ def rmsnorm2d_fwd_with_add(
     weight: Tensor,
     epsilon: float,
     use_model_sensitive_rmsnorm: int = 0,
-) -> None: ...
+) -> None:
+    if use_model_sensitive_rmsnorm > 0 or input.shape[-1] > 8192:
+        rmsnorm2d_fwd_with_add_ck(
+            out,
+            input,
+            residual_in,
+            residual_out,
+            weight,
+            epsilon,
+            use_model_sensitive_rmsnorm,
+        )
+    else:
+        add_rmsnorm(out, input, residual_in, residual_out, weight, epsilon)
 
 
 @compile_ops("module_rmsnorm")
@@ -107,18 +123,26 @@ def rmsnorm2d_fwd_with_add_smoothquant(
 ) -> None: ...
 
 
-@compile_ops("module_rmsnorm")
 def rmsnorm2d_fwd_with_dynamicquant(
     out: Tensor,
     input: Tensor,
     yscale: Tensor,
     weight: Tensor,
     epsilon: float,
     use_model_sensitive_rmsnorm: int = 0,
-) -> None: ...
+    group_size: int = 0,
+    shuffle_scale: bool = False,
+) -> None:
+    if use_model_sensitive_rmsnorm > 0 or input.shape[-1] > 8192:
+        assert group_size == 0, "group_size is not supported for ck rmsnorm"
+        assert not shuffle_scale, "shuffle_scale is not supported for ck rmsnorm"
+        rmsnorm2d_fwd_with_dynamicquant_ck(
+            out, input, yscale, weight, epsilon, use_model_sensitive_rmsnorm
+        )
+    else:
+        rmsnorm_quant(out, input, yscale, weight, epsilon, group_size, shuffle_scale)
 
 
-@compile_ops("module_rmsnorm")
 def rmsnorm2d_fwd_with_add_dynamicquant(
     out: Tensor,
     input: Tensor,
@@ -128,4 +152,124 @@ def rmsnorm2d_fwd_with_add_dynamicquant(
     weight: Tensor,
     epsilon: float,
     use_model_sensitive_rmsnorm: int = 0,
+    group_size: int = 0,
+    shuffle_scale: bool = False,
+) -> None:
+    if use_model_sensitive_rmsnorm > 0 or input.shape[-1] > 8192:
+        assert group_size == 0, "group_size is not supported for ck rmsnorm"
+        assert not shuffle_scale, "shuffle_scale is not supported for ck rmsnorm"
+        rmsnorm2d_fwd_with_add_dynamicquant_ck(
+            out,
+            input,
+            residual_in,
+            residual_out,
+            yscale,
+            weight,
+            epsilon,
+            use_model_sensitive_rmsnorm,
+        )
+    else:
+        add_rmsnorm_quant(
+            out,
+            input,
+            residual_in,
+            residual_out,
+            yscale,
+            weight,
+            epsilon,
+            group_size,
+            shuffle_scale,
+        )
+
+
+@compile_ops(
+    "module_rmsnorm", gen_fake=gen_rms_norm_fake_tensor, fc_name="rmsnorm2d_fwd"
+)
+def rmsnorm2d_fwd_ck(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    epsilon: float,
+    use_model_sensitive_rmsnorm: int = 0,
+) -> Tensor: ...
+
+
+@compile_ops("module_rmsnorm", fc_name="rmsnorm2d_fwd_with_add")
+def rmsnorm2d_fwd_with_add_ck(
+    out: Tensor,
+    input: Tensor,
+    residual_in: Tensor,
+    residual_out: Tensor,
+    weight: Tensor,
+    epsilon: float,
+    use_model_sensitive_rmsnorm: int = 0,
+) -> None: ...
+
+
+@compile_ops("module_rmsnorm", fc_name="rmsnorm2d_fwd_with_dynamicquant")
+def rmsnorm2d_fwd_with_dynamicquant_ck(
+    out: Tensor,
+    input: Tensor,
+    yscale: Tensor,
+    weight: Tensor,
+    epsilon: float,
+    use_model_sensitive_rmsnorm: int = 0,
+) -> None: ...
+
+
+@compile_ops("module_rmsnorm", fc_name="rmsnorm2d_fwd_with_add_dynamicquant")
+def rmsnorm2d_fwd_with_add_dynamicquant_ck(
+    out: Tensor,
+    input: Tensor,
+    residual_in: Tensor,
+    residual_out: Tensor,
+    yscale: Tensor,
+    weight: Tensor,
+    epsilon: float,
+    use_model_sensitive_rmsnorm: int = 0,
+) -> None: ...
+
+
+@compile_ops("module_rmsnorm_quant")
+def add_rmsnorm_quant(
+    out: Tensor,
+    input: Tensor,
+    residual_in: Tensor,
+    residual_out: Tensor,
+    scale: Tensor,
+    weight: Tensor,
+    epsilon: float,
+    group_size: int = 0,
+    shuffle_scale: bool = False,
+) -> None: ...
+
+
+@compile_ops("module_rmsnorm_quant")
+def add_rmsnorm(
+    out: Tensor,
+    input: Tensor,
+    residual_in: Tensor,
+    residual_out: Tensor,
+    weight: Tensor,
+    epsilon: float,
+) -> None: ...
+
+
+@compile_ops("module_rmsnorm_quant")
+def rmsnorm_quant(
+    out: Tensor,
+    input: Tensor,
+    scale: Tensor,
+    weight: Tensor,
+    epsilon: float,
+    group_size: int = 0,
+    shuffle_scale: bool = False,
+) -> None: ...
+
+
+@compile_ops("module_rmsnorm_quant")
+def rmsnorm(
+    out: Tensor,
+    input: Tensor,
+    weight: Tensor,
+    epsilon: float,
 ) -> None: ...
```
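The rmsnorm entry points are now thin Python dispatchers: calls with use_model_sensitive_rmsnorm > 0 or a hidden dimension above 8192 route to the CK-backed *_ck stubs (bound to the original C++ symbols via fc_name), while everything else uses the new module_rmsnorm_quant kernels. A usage sketch, assuming a ROCm device and that the import path follows the file location above:

```python
import torch
from aiter.ops.rmsnorm import rmsnorm2d_fwd  # path inferred from the diff above

x = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
w = torch.ones(4096, dtype=torch.bfloat16, device="cuda")

# hidden dim 4096 <= 8192: dispatches to the new fused rmsnorm kernel
y = rmsnorm2d_fwd(x, w, 1e-6)

# hidden dim 16384 > 8192: falls back to the CK implementation
x_wide = torch.randn(32, 16384, dtype=torch.bfloat16, device="cuda")
w_wide = torch.ones(16384, dtype=torch.bfloat16, device="cuda")
y_wide = rmsnorm2d_fwd(x_wide, w_wide, 1e-6)
```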

aiter/ops/triton/__init__.py
Lines changed: 1 addition & 0 deletions

```diff
@@ -113,6 +113,7 @@
     "moe_op_silu_fused": "moe.moe_op_silu_fused",
     "moe_op": "moe.moe_op",
     "moe_routing_sigmoid_top1_fused": "moe.moe_routing_sigmoid_top1_fused",
+    "moe_routing": "moe.moe_routing",
     "quant_moe": "moe.quant_moe",
     # Normalization modules (normalization/)
     "fused_add_rmsnorm_pad": "normalization.fused_add_rmsnorm_pad",
```

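The string-to-string mapping registers the new moe_routing module for lazy loading. A sketch of the PEP 562 pattern such registries typically drive (whether aiter's __init__ resolves names exactly this way is an assumption):

```python
import importlib

_module_map = {
    "moe_routing": "moe.moe_routing",  # the entry added in this commit
}

def __getattr__(name):  # PEP 562: module-level attribute hook
    if name in _module_map:
        return importlib.import_module(f"aiter.ops.triton.{_module_map[name]}")
    raise AttributeError(f"module 'aiter.ops.triton' has no attribute {name!r}")
```
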
aiter/ops/triton/gluon/pa_decode_gluon.py
Lines changed: 2 additions & 1 deletion

```diff
@@ -5,7 +5,6 @@
 import torch
 import aiter
 import aiter.ops.triton.utils._triton.arch_info as arch_info
-from aiter.ops.triton.utils.types import torch_to_triton_dtype
 
 import triton
 import triton.language as tl
@@ -3365,6 +3364,8 @@ def pa_decode_gluon(
         raise RuntimeError(
             "This version triton is not support gluon jit mode, please upgrade to 3.5.0 or higher!"
         )
+    from aiter.ops.triton.utils.types import torch_to_triton_dtype
+
     cdna_version = get_cdna_version()
     assert cdna_version in [
         3,
```
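Moving the torch_to_triton_dtype import from module scope to just past the version check means importing pa_decode_gluon itself can no longer fail on an incompatible Triton; the dependency is only resolved once the caller has cleared the gluon-support gate. The general shape of the pattern (names below are generic, not aiter's; the exact motivation here is an assumption):

```python
def run_kernel():
    import triton

    # Gate first: fail with a clear message instead of a confusing ImportError.
    major, minor = (int(p) for p in triton.__version__.split(".")[:2])
    if (major, minor) < (3, 5):
        raise RuntimeError("gluon jit mode needs triton >= 3.5")

    # Deferred import: only executed after the capability check passes.
    from mypackage.utils.types import torch_to_triton_dtype  # hypothetical path
    ...
```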

aiter/utility/base_tuner.py
Lines changed: 15 additions & 3 deletions

```diff
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: MIT
-# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
 
 import os
+import sys
 import argparse
 import torch
 import pandas as pd
@@ -116,9 +117,10 @@ def _setup_common_arguments(self):
         )
         self.parser.add_argument(
             "--sort",
-            action="store_true",
+            type=dtypes.str2bool,
+            default=defaults.get("sort", False),
             required=False,
-            help="Arranged according to the keys",
+            help="Arranged according to the keys (True/False)",
         )
         self.parser.add_argument(
             "--errRatio",
@@ -410,6 +412,11 @@ def tune_summary(self, status):
         if not self.remain_untuned.empty:
             logger.info("untuned shapes:")
             print(self.remain_untuned)
+        if not self.remain_untuned.empty or not self.failed.empty:
+            logger.error(
+                "\033[91m[Tuning not Finished]\033[0m some shapes are not tuned or all failed, please check the result file or tune with --profile_file to get more details"
+            )
+            sys.exit(1)
 
     @abstractmethod
     def result_to_csv(self, results, file, concat=False):
@@ -480,6 +487,11 @@ def run(self, args, fast_mode=False):
 
 class GemmCommonTuner(TunerCommon):
 
+    ARG_DEFAULTS = {
+        **TunerCommon.ARG_DEFAULTS,
+        "sort": True,  # Enable sorting by default for GEMM tuners
+    }
+
     def __init__(
         self,
         name,
```
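Replacing action="store_true" with a str2bool-typed option is what makes the new per-class default work: a store_true flag can only ever turn the option on, whereas a True default from GemmCommonTuner.ARG_DEFAULTS can now be overridden with an explicit --sort False. A minimal sketch of the idiom, assuming dtypes.str2bool behaves like the common argparse helper:

```python
import argparse

def str2bool(v):
    # Assumed behaviour of aiter's dtypes.str2bool; the real helper may differ.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "1"):
        return True
    if v.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")

parser = argparse.ArgumentParser()
# The default would come from the tuner's ARG_DEFAULTS, e.g. True for GEMM.
parser.add_argument("--sort", type=str2bool, default=True, required=False)

assert parser.parse_args([]).sort is True                    # class default wins
assert parser.parse_args(["--sort", "False"]).sort is False  # explicit opt-out
```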