
Commit 7f35284

rmsnorm multiple datatype tests (#705)
* add multiple in_datatype and out_datatype to the tests
* change init method to fix TE integration issue
1 parent: 2f471c0
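For context on the second bullet (the init change), here is a minimal, hypothetical Triton sketch that is not part of this commit: with a plain Python-float init, Triton's type promotion keeps the accumulator in fp32 once fp32 values are added to it, so the explicit `sum_squares: tl.float32 = 0.` annotation that older Triton releases reject is unnecessary. The kernel name, shapes, and launch parameters below are illustrative only.

import torch
import triton
import triton.language as tl


@triton.jit
def sum_squares_kernel(x_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    # Illustrative kernel (assumption: any recent Triton release); it mirrors the
    # accumulation pattern in rms_kernel without the fp32 type annotation.
    col_offsets = tl.arange(0, BLOCK_SIZE)
    sum_squares = 0.  # promoted to fp32 by the fp32 additions below
    for blk_idx in range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        cols = blk_idx * BLOCK_SIZE + col_offsets
        mask = cols < n_cols
        x = tl.load(x_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        sum_squares += tl.sum(x * x, axis=0)
    tl.store(out_ptr, sum_squares)


x = torch.randn(1000, device='cuda', dtype=torch.float16)
out = torch.empty(1, device='cuda', dtype=torch.float32)
sum_squares_kernel[(1, )](x, out, x.numel(), BLOCK_SIZE=256)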


python/perf-kernels/rmsnorm.py

Lines changed: 36 additions & 16 deletions
@@ -64,7 +64,10 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, rsigma_ptr, input_row_stride, outpu
 
             # Accumulate sum of squares
             n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1
-            sum_squares: tl.float32 = 0.
+            # older version of triton doesn't accept below init
+            # sum_squares: tl.float32 = 0.
+            # however, with type promoting rule in triton, sum_squares should be always fp32 with below init
+            sum_squares = 0.
             for blk_idx in tl.range(0, n_cols_blks, num_stages=2):
                 cols = blk_idx * BLOCK_SIZE + col_offsets
                 input_ptrs = row_input_ptr + cols
@@ -147,54 +150,71 @@ def triton_rmsnorm(x, y, g, rsigma, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_siz
     return y, rsigma
 
 
-def torch_rmsnorm(x, g, ZERO_CENTERED_GAMMA, epsilon=1e-6):
+def torch_rmsnorm(x, g, ZERO_CENTERED_GAMMA, out_dtype=torch.float16, epsilon=1e-6):
     M, N = x.shape
-    rms = torch.sqrt(torch.sum(x * x, dim=-1) * 1 / N)
+    # cast to float32 as the triton kernel
+    x_f32 = x.float()
+    g_f32 = g.float()
+    rms = torch.sqrt(torch.sum(x_f32 * x_f32, dim=-1) * 1 / N)
     rsigma = 1.0 / rms
     if (ZERO_CENTERED_GAMMA):
-        g += 1
-    rms_norm = x * rsigma.unsqueeze(1) * g
-    rms_norm = rms_norm.to(x.dtype)
+        g_f32 += 1
+    rms_norm_f32 = x_f32 * rsigma.unsqueeze(1) * g_f32
+    rms_norm = rms_norm_f32.to(out_dtype)
     return rms_norm, rsigma
 
 
+arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32}
+
+
+@pytest.mark.parametrize("in_dtype_str", ["fp32", "fp16", "bf16"])
+@pytest.mark.parametrize("out_dtype_str", ["fp32", "fp16", "bf16"])
 @pytest.mark.parametrize('ZERO_CENTERED_GAMMA', [True, False])
 @pytest.mark.parametrize('M, N', [
     (1, 4),
     (2, 10),
     (8192, 4096),
     (4096, 8192),
-    (1, 8192),
     (1, 31744),
     (3, 65536),
     (873, 1245),
 ])
-def test_rmsnorm(M, N, ZERO_CENTERED_GAMMA):
+def test_rmsnorm(M, N, ZERO_CENTERED_GAMMA, in_dtype_str, out_dtype_str):
+    in_dtype = arg_to_torch_dtype[in_dtype_str]
+    out_dtype = arg_to_torch_dtype[out_dtype_str]
     torch.manual_seed(0)
-    x = torch.randn(M, N, device='cuda')
-    y = torch.zeros_like(x, device='cuda')
+    x = torch.randn(M, N, device='cuda', dtype=in_dtype)
+    y = torch.zeros_like(x, device='cuda', dtype=out_dtype)
     rsigma = torch.empty((M, ), device='cuda', dtype=torch.float32)
 
     n_rows, n_cols = x.shape
     MAX_FUSED_SIZE = 65536 // x.element_size()
     blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
     USE_BLOCKED = n_cols > blk_size
     NUM_PRGMS = min(n_rows, get_num_sms())
-    g = torch.ones((1, N), device='cuda')
+    g = torch.ones((1, N), device='cuda', dtype=in_dtype)
 
     y_triton, rsigma_triton = triton_rmsnorm(x, y, g, rsigma, n_rows, n_cols, ZERO_CENTERED_GAMMA, blk_size,
                                              USE_BLOCKED, NUM_PRGMS)
 
-    y_torch, rsigma_torch = torch_rmsnorm(x, g, ZERO_CENTERED_GAMMA)
+    y_torch, rsigma_torch = torch_rmsnorm(x, g, ZERO_CENTERED_GAMMA, out_dtype)
 
-    assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
-    assert torch.allclose(rsigma_triton, rsigma_torch), (rsigma_triton, rsigma_torch)
+    if out_dtype in (torch.float16, torch.bfloat16):
+        atol, rtol = 1e-3, 1e-2
+    else:
+        # float32 typically can be tighter
+        atol, rtol = 1e-5, 1e-5
 
+    assert y_triton.dtype == out_dtype, f"y_triton has dtype={y_triton.dtype}, expected {out_dtype}"
+    assert y_torch.dtype == out_dtype, f"y_torch has dtype={y_torch.dtype}, expected {out_dtype}"
 
-#Benchmark
-arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32}
+    assert torch.allclose(y_triton, y_torch, atol=atol, rtol=rtol), \
+        f"Mismatch in 'y' (in={in_dtype_str}, out={out_dtype_str})"
+    assert torch.allclose(rsigma_triton, rsigma_torch, atol=atol, rtol=rtol), \
+        f"Mismatch in 'rsigma' (in={in_dtype_str}, out={out_dtype_str})"
 
 
+#Benchmark
 def model_benchmark_configs(args):
     config_file = args.model_configs
     configs = get_model_configs(config_path=config_file, model_families=["llama3"], model=args.model)
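For readers skimming the diff, a minimal sketch (not part of this commit) of the reference path the new test checks: RMSNorm is computed in fp32, the result is cast to the requested output dtype, and the comparison uses looser tolerances for fp16/bf16 outputs than for fp32. The helper name `rmsnorm_ref` and the shapes are illustrative; the committed reference is `torch_rmsnorm` above.

import torch


def rmsnorm_ref(x, g, out_dtype):
    # upcast to fp32 so the reference matches the kernel's internal precision
    x_f32, g_f32 = x.float(), g.float()
    rms = torch.sqrt(torch.mean(x_f32 * x_f32, dim=-1, keepdim=True))
    return (x_f32 / rms * g_f32).to(out_dtype)


x = torch.randn(8, 128, dtype=torch.bfloat16)
g = torch.ones(128, dtype=torch.bfloat16)
y = rmsnorm_ref(x, g, out_dtype=torch.float16)

# dtype-dependent tolerances, as in the updated test
atol, rtol = (1e-3, 1e-2) if y.dtype in (torch.float16, torch.bfloat16) else (1e-5, 1e-5)
assert y.dtype == torch.float16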
