Skip to content

Commit 1c564c2

Browse files
authored
changed test_silu_mul_quant_fuse and test_rmsnorm_quant_fuse to use pytorch kernels in reference function instead of aiter non-triton kernels (ROCm#1891)
1 parent ebd0485 commit 1c564c2

File tree

2 files changed

+21
-25
lines changed

2 files changed

+21
-25
lines changed

aiter/utility/dtypes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
defaultDtypes = {
99
"gfx942": {"fp8": torch.float8_e4m3fnuz},
1010
"gfx950": {"fp8": torch.float8_e4m3fn},
11+
"gfx1250": {"fp8": torch.float8_e4m3fn},
1112
}
1213

1314
_8bit_fallback = torch.uint8

op_tests/triton_tests/quant/test_fused_fp8_quant.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,9 @@
99
fused_silu_mul_fp8_per_tensor_static_quant,
1010
)
1111

12-
from aiter import (
13-
rmsnorm2d_fwd,
14-
silu_and_mul,
15-
)
16-
1712
from aiter.test_common import (
1813
checkAllclose,
1914
)
20-
21-
from aiter.ops.quant import per_tensor_quant_hip
2215
import aiter
2316
import torch.nn.functional as F
2417

@@ -198,17 +191,11 @@ def test_fused_rms_fp8_group_quant(M: int, N1: int, N2: int, dtype):
198191
torch.testing.assert_close(y1_upcast_torch, y1_upcast_triton, atol=0.1, rtol=0.1)
199192

200193

201-
def rmsnorm2d_fwd_(
202-
x: torch.Tensor, weight: torch.Tensor, eps: float, dim: int
203-
) -> torch.Tensor:
204-
ori_shape = x.shape
205-
x = x.reshape(-1, dim)
206-
return rmsnorm2d_fwd(x, weight, eps).view(ori_shape)
207-
208-
209-
def rmsnorm_fp8_quantization_ref(x, w, x_scale, eps, n, rocm_fp8_dtype):
210-
rms_out = rmsnorm2d_fwd_(x, w, eps, n)
211-
quant_out, _ = per_tensor_quant_hip(rms_out, x_scale, rocm_fp8_dtype)
194+
def rmsnorm_fp8_quantization_ref(x, w, x_scale, eps, rocm_fp8_dtype):
195+
rms_out = rmsnorm(x.to(torch.float32), w.to(torch.float32), eps).to(x.dtype)
196+
quant_out = per_tensor_fp8_static_quant(
197+
rms_out.to(torch.float32), rocm_fp8_dtype, x_scale.to(torch.float32)
198+
)
212199
return quant_out, rms_out
213200

214201

@@ -248,14 +235,14 @@ def test_rmsnorm_quant_fuse(m, n):
248235
)
249236

250237
# calculate the correct scale value
251-
rms_out = rmsnorm2d_fwd_(x, w, eps, n)
238+
rms_out = rmsnorm(x.to(torch.float32), w.to(torch.float32), eps)
252239
rms_out_abs = torch.abs(rms_out)
253240
rms_out_abs_max = torch.max(rms_out_abs)
254241
scale_val = rms_out_abs_max / DTYPE_MAX
255242
x_scale = torch.tensor((scale_val), dtype=torch.float32, device="cuda")
256243

257244
fp8_x_ref, rms_out_ref = rmsnorm_fp8_quantization_ref(
258-
x, w, x_scale, eps, n, rocm_fp8_dtype
245+
x, w, x_scale, eps, rocm_fp8_dtype
259246
)
260247
fp8_x, rms_out = triton_rmsnorm_fp8_quantization_fuse(
261248
x, w, x_scale, eps, rocm_fp8_dtype
@@ -658,9 +645,13 @@ def silu_mul_fp8_quantization_ref(x, x_scale, rocm_fp8_dtype):
658645
m, n2 = x.shape
659646
assert n2 % 2 == 0
660647
n = n2 // 2
661-
silu_out = torch.empty((m, n), dtype=x.dtype, device=x.device)
662-
silu_and_mul(silu_out, x)
663-
quant_out, _ = per_tensor_quant_hip(silu_out, x_scale, rocm_fp8_dtype)
648+
x1, x2 = x.split([n, n], dim=-1)
649+
silu_out = (
650+
(F.silu(x1.to(torch.float32)) * x2.to(torch.float32))
651+
.to(x.dtype)
652+
.to(torch.float32)
653+
)
654+
quant_out = per_tensor_fp8_static_quant(silu_out, rocm_fp8_dtype, x_scale)
664655
return quant_out
665656

666657

@@ -688,8 +679,12 @@ def test_silu_mul_quant_fuse(m, n):
688679
)
689680

690681
# calculate the correct scale value
691-
silu_out = torch.empty((m, n), dtype=x.dtype, device=x.device)
692-
silu_and_mul(silu_out, x)
682+
x1, x2 = x.split([n, n], dim=-1)
683+
silu_out = (
684+
(F.silu(x1.to(torch.float32)) * x2.to(torch.float32))
685+
.to(x.dtype)
686+
.to(torch.float32)
687+
)
693688
silu_out_abs = torch.abs(silu_out)
694689
silu_out_abs_max = torch.max(silu_out_abs)
695690
scale_val = silu_out_abs_max / DTYPE_MAX

0 commit comments

Comments
 (0)