|
9 | 9 | fused_silu_mul_fp8_per_tensor_static_quant, |
10 | 10 | ) |
11 | 11 |
|
12 | | -from aiter import ( |
13 | | - rmsnorm2d_fwd, |
14 | | - silu_and_mul, |
15 | | -) |
16 | | - |
17 | 12 | from aiter.test_common import ( |
18 | 13 | checkAllclose, |
19 | 14 | ) |
20 | | - |
21 | | -from aiter.ops.quant import per_tensor_quant_hip |
22 | 15 | import aiter |
23 | 16 | import torch.nn.functional as F |
24 | 17 |
|
@@ -198,17 +191,11 @@ def test_fused_rms_fp8_group_quant(M: int, N1: int, N2: int, dtype): |
198 | 191 | torch.testing.assert_close(y1_upcast_torch, y1_upcast_triton, atol=0.1, rtol=0.1) |
199 | 192 |
|
200 | 193 |
|
def rmsnorm_fp8_quantization_ref(x, w, x_scale, eps, rocm_fp8_dtype):
    """Reference for fused RMSNorm + static per-tensor FP8 quantization.

    Computes RMSNorm in float32, then quantizes the result to FP8 with the
    given static scale. Returns ``(quant_out, rms_out)`` where ``rms_out`` is
    the normalized tensor in ``x.dtype`` (the pre-quantization intermediate).
    """
    x_f32 = x.to(torch.float32)
    w_f32 = w.to(torch.float32)
    # Round-trip through x.dtype to mimic the fused kernel's intermediate
    # precision before the final fp32 quantization step.
    rms_out = rmsnorm(x_f32, w_f32, eps).to(x.dtype)
    quant_out = per_tensor_fp8_static_quant(
        rms_out.to(torch.float32), rocm_fp8_dtype, x_scale.to(torch.float32)
    )
    return quant_out, rms_out
213 | 200 |
|
214 | 201 |
|
@@ -248,14 +235,14 @@ def test_rmsnorm_quant_fuse(m, n): |
248 | 235 | ) |
249 | 236 |
|
250 | 237 | # calculate the correct scale value |
251 | | - rms_out = rmsnorm2d_fwd_(x, w, eps, n) |
| 238 | + rms_out = rmsnorm(x.to(torch.float32), w.to(torch.float32), eps) |
252 | 239 | rms_out_abs = torch.abs(rms_out) |
253 | 240 | rms_out_abs_max = torch.max(rms_out_abs) |
254 | 241 | scale_val = rms_out_abs_max / DTYPE_MAX |
255 | 242 | x_scale = torch.tensor((scale_val), dtype=torch.float32, device="cuda") |
256 | 243 |
|
257 | 244 | fp8_x_ref, rms_out_ref = rmsnorm_fp8_quantization_ref( |
258 | | - x, w, x_scale, eps, n, rocm_fp8_dtype |
| 245 | + x, w, x_scale, eps, rocm_fp8_dtype |
259 | 246 | ) |
260 | 247 | fp8_x, rms_out = triton_rmsnorm_fp8_quantization_fuse( |
261 | 248 | x, w, x_scale, eps, rocm_fp8_dtype |
def silu_mul_fp8_quantization_ref(x, x_scale, rocm_fp8_dtype):
    """Reference for fused SiLU-and-mul + static per-tensor FP8 quantization.

    ``x`` has shape ``(m, 2*n)``; the first half is passed through SiLU and
    multiplied elementwise by the second half, then the product is quantized
    to FP8 with the static scale ``x_scale``. Returns the quantized tensor.
    """
    _, n2 = x.shape  # row count is unused; only the feature dim matters
    assert n2 % 2 == 0, "last dim must split evenly into two halves"
    n = n2 // 2
    x1, x2 = x.split([n, n], dim=-1)
    # Compute in fp32, round-trip through x.dtype to mimic the fused kernel's
    # intermediate precision, then return to fp32 for quantization.
    silu_out = (
        (F.silu(x1.to(torch.float32)) * x2.to(torch.float32))
        .to(x.dtype)
        .to(torch.float32)
    )
    quant_out = per_tensor_fp8_static_quant(silu_out, rocm_fp8_dtype, x_scale)
    return quant_out
665 | 656 |
|
666 | 657 |
|
@@ -688,8 +679,12 @@ def test_silu_mul_quant_fuse(m, n): |
688 | 679 | ) |
689 | 680 |
|
690 | 681 | # calculate the correct scale value |
691 | | - silu_out = torch.empty((m, n), dtype=x.dtype, device=x.device) |
692 | | - silu_and_mul(silu_out, x) |
| 682 | + x1, x2 = x.split([n, n], dim=-1) |
| 683 | + silu_out = ( |
| 684 | + (F.silu(x1.to(torch.float32)) * x2.to(torch.float32)) |
| 685 | + .to(x.dtype) |
| 686 | + .to(torch.float32) |
| 687 | + ) |
693 | 688 | silu_out_abs = torch.abs(silu_out) |
694 | 689 | silu_out_abs_max = torch.max(silu_out_abs) |
695 | 690 | scale_val = silu_out_abs_max / DTYPE_MAX |
|
0 commit comments