Commit 6bbd03c

[TE] Implement Triton current scaling (#341)
1 parent 5685b2c commit 6bbd03c

File tree

ci/pytorch.sh
tests/pytorch/test_numerics.py
tests/pytorch/triton_kernels/test_cast.py
transformer_engine/pytorch/module/grouped_linear.py
transformer_engine/pytorch/tensor/float8_tensor.py
transformer_engine/pytorch/triton_kernels/cast.py
transformer_engine/pytorch/triton_kernels/cast_transpose.py

7 files changed: +248 -24 lines

ci/pytorch.sh

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ run_test_config(){
     run_default_fa 1 test_parallel_cross_entropy.py
     NVTE_USE_DEQUANTIZE_TRITON=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 NVTE_USE_RMSNORM_TRITON=1 NVTE_USE_LAYERNORM_TRITON=1 run_default_fa_lbl "triton" 1 test_numerics.py
     NVTE_USE_RMSNORM_TRITON=1 run_default_fa_lbl "triton" 1 test_fusible_ops.py
+    NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa_lbl "triton" 1 test_float8_current_scaling_exact.py
 }
 
 run_test_config_mgpu(){

tests/pytorch/test_numerics.py

Lines changed: 0 additions & 7 deletions
@@ -737,10 +737,6 @@ def test_gpt_full_activation_recompute(
         pytest.skip(reason_for_no_fp8)
     if recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)
-    if IS_HIP_EXTENSION:
-        use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) )
-        if fp8 and recipe.float8_current_scaling() and use_cast_transpose_triton:
-            pytest.skip("Float8 Current Scaling unsupported for full recompute.")
     if recipe.float8_block_scaling() and not fp8_block_scaling_available:
         pytest.skip(reason_for_no_fp8_block_scaling)

@@ -1959,9 +1955,6 @@ def test_grouped_linear_accuracy(
     if IS_HIP_EXTENSION:
         if dtype not in (torch.float32,) and fuse_wgrad_accumulation and not fp8:
             pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.")
-        use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) )
-        if fp8 and recipe.float8_current_scaling() and use_cast_transpose_triton:
-            pytest.skip("Float8 Current Scaling unsupported for grouped linear accuracy.")
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
     if fp8 and recipe.mxfp8() and not mxfp8_available:

tests/pytorch/triton_kernels/test_cast.py

Lines changed: 65 additions & 10 deletions
@@ -1,16 +1,20 @@
 # Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 # License for AMD contributions = MIT. See LICENSE for more information
 
-import os
 import pytest
 import torch
 
 from transformer_engine.pytorch.triton_kernels.cast import te_quantize_triton
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.triton_kernels.cast_transpose import _compute_scale_from_amax_triton
+from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
 from transformer_engine.pytorch.triton_kernels.common import te_dtype_to_torch_dtype
 import transformer_engine_torch as tex
 from test_common import te_compare_results, fill_uniform, get_tolerances
+from transformer_engine.pytorch.fp8 import fp8_autocast
+from transformer_engine.common import recipe
+from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type
 
+@pytest.mark.parametrize("scaling", ("delayed", "current"))
 @pytest.mark.parametrize("shape",
     [
         (16 ),

@@ -32,17 +36,30 @@
     ])
 @pytest.mark.parametrize("in_dtype", [torch.float32, torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("out_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
-def test_quantize(shape, in_dtype, out_dtype):
+def test_quantize(scaling, shape, in_dtype, out_dtype):
     input_tensor = fill_uniform(shape, dtype=in_dtype)
 
-    scale_tensor = torch.rand(1, dtype=torch.float32, device='cuda') * 3.0 - 2.0
-    amax_tensor = torch.zeros(1, dtype=torch.float32, device='cuda')
-    triton_quantizer = Float8Quantizer(scale=scale_tensor, amax=amax_tensor, fp8_dtype=out_dtype)
+    if scaling == "current":
+        triton_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=out_dtype, device="cuda")
+        tex_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=out_dtype, device="cuda")
+
+        with fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()):
+            quantized_out_triton = te_quantize_triton(input_tensor, quantizer=triton_quantizer)
+            quantized_out_tex = tex.quantize(input_tensor, tex_quantizer)
+
+    elif scaling == "delayed":
+        scale_tensor = torch.rand(1, dtype=torch.float32, device='cuda') * 3.0 - 2.0
+        amax_tensor = torch.zeros(1, dtype=torch.float32, device='cuda')
+
+        triton_quantizer = Float8Quantizer(scale=scale_tensor, amax=amax_tensor, fp8_dtype=out_dtype)
+        tex_quantizer = Float8Quantizer(scale=scale_tensor, amax=amax_tensor, fp8_dtype=out_dtype)
+
+        quantized_out_triton = te_quantize_triton(input_tensor, quantizer=triton_quantizer)
+        quantized_out_tex = tex.quantize(input_tensor, tex_quantizer)
+
+    else:
+        raise ValueError(f"unknown scaling method {scaling}")
 
-    quantized_out_triton = te_quantize_triton(input_tensor, quantizer=triton_quantizer)
-
-    tex_quantizer = Float8Quantizer(scale=scale_tensor, amax=amax_tensor, fp8_dtype=out_dtype)
-    quantized_out_tex = tex.quantize(input_tensor, tex_quantizer)
     torch_out_dtype = te_dtype_to_torch_dtype(out_dtype)
 
     atol_q, rtol_q = get_tolerances(torch_out_dtype)

@@ -112,3 +129,41 @@ def test_quantize_bad_transpose(t_shape, fp8_dtype):
     quantized_output._transpose = torch.empty(t_shape, device='cuda')
 
     te_quantize_triton(input_tensor, quantizer=quantizer, output=quantized_output)
+
+
+@pytest.mark.parametrize("amax_val", (0.0, float('nan'), float('inf'), -float('inf'), 1.0, 1e-8, 123.456))
+@pytest.mark.parametrize("force_pow_2_scales", (False, True))
+@pytest.mark.parametrize("epsilon", (0.0, 1e-3, 100.0))
+@pytest.mark.parametrize("fp8_dtype", (get_torch_float8_e4m3_type(), get_torch_float8_e5m2_type()))
+def test_compute_scale_from_amax(amax_val, force_pow_2_scales, epsilon, fp8_dtype):
+    max_fp8 = torch.finfo(fp8_dtype).max
+    value_for_inf = float(torch.finfo(torch.float32).max)
+
+    amax_list = [torch.tensor(amax_val, dtype=torch.float32, device="cuda")]
+
+    # TEX path - TEX expects lists for (amaxes, scales, inv_scales)
+    scale_ref = [torch.empty((), dtype=torch.float32, device="cuda")]
+    scale_inv_ref = [torch.empty((), dtype=torch.float32, device="cuda")]
+
+    chunk_size = 2048 * 32  # arbitrary
+    overflow_buf = torch.zeros(1, dtype=torch.int32, device="cuda")
+    tex.multi_tensor_compute_scale_and_scale_inv(
+        chunk_size,
+        overflow_buf,
+        [amax_list, scale_ref, scale_inv_ref],
+        max_fp8,
+        force_pow_2_scales,
+        epsilon,
+    )
+
+    # Triton path & comparison
+    scale_triton = torch.empty((), dtype=torch.float32, device="cuda")
+    scale_inv_triton = torch.empty((), dtype=torch.float32, device="cuda")
+    _compute_scale_from_amax_triton[(1,)](
+        amax_list[0], scale_triton, scale_inv_triton,
+        float(max_fp8), float(epsilon), float(value_for_inf),
+        FORCE_POW_2_SCALES=force_pow_2_scales,
+    )
+
+    torch.testing.assert_close(scale_triton, scale_ref[0], rtol=0.0, atol=0.0)
+    torch.testing.assert_close(scale_inv_triton, scale_inv_ref[0], rtol=0.0, atol=0.0)
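
For reference, the scale computation that test_compute_scale_from_amax exercises fits in a few lines of plain Python. The sketch below is a CPU-side restatement of the semantics only (the function name and the use of math instead of torch are illustrative, not part of this change):

import math

def compute_scale_from_amax_ref(amax, max_fp8, epsilon=0.0,
                                force_pow_2_scales=False,
                                value_for_inf=3.402823466e38):
    # Small amax values are clamped up to epsilon; NaN passes through the comparison.
    if amax < epsilon:
        amax = epsilon
    # Degenerate amax (NaN, +/-inf, 0.0) falls back to scale = 1.0.
    if math.isnan(amax) or math.isinf(amax) or amax == 0.0:
        scale = 1.0
    else:
        scale = max_fp8 / amax
        if math.isinf(scale):
            # A division that overflows is clamped to the largest finite float32.
            scale = value_for_inf
        if force_pow_2_scales:
            # Round the scale down to the nearest power of two.
            scale = 2.0 ** math.floor(math.log2(scale))
    return scale, 1.0 / scale

# Example: amax 123.456 against the E4M3 maximum (448 for e4m3fn, 240 for the
# e4m3fnuz variant used on ROCm).
print(compute_scale_from_amax_ref(123.456, 240.0, force_pow_2_scales=True))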

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 18 additions & 3 deletions
@@ -1,3 +1,5 @@
+# This file was modified for portability to AMDGPU
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

@@ -8,6 +10,7 @@
 
 import functools
 import torch
+import os
 
 import transformer_engine_torch as tex
 

@@ -49,6 +52,7 @@
     prepare_for_saving,
     restore_from_saved,
 )
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 
 __all__ = ["GroupedLinear"]
 

@@ -125,9 +129,20 @@ def forward(
             recipe = FP8GlobalStateManager.get_fp8_recipe()
             if hasattr(recipe, "fp8_gemm_fprop"):
                 fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
-            inputmats = tex.fused_multi_quantize(
-                inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype]
-            )
+
+            if IS_HIP_EXTENSION and bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) ):
+                # The Triton path has no equivalent for tex.fused_multi_quantize()
+                inputmats = []
+                for i, x in enumerate(inputmats_no_fp8):
+                    qi = input_quantizers[i]
+                    dst = qi.make_empty(x.shape, dtype=x.dtype, device=x.device, requires_grad=False)
+                    qi.update_quantized(x, dst, noop_flag=None)
+                    inputmats.append(dst)
+            else:
+                inputmats = tex.fused_multi_quantize(
+                    inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype]
+                )
+
         weights_fp8 = []
         bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
         # FP8 cast to workspace buffer
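
The per-tensor fallback above runs whenever GroupedLinear executes under an FP8 recipe with NVTE_USE_CAST_TRANSPOSE_TRITON=1 on ROCm. A hypothetical usage sketch, assuming the standard te.GroupedLinear(num_gemms, in_features, out_features) constructor with its params_dtype/device keyword arguments (sizes and splits here are illustrative):

import os
os.environ["NVTE_USE_CAST_TRANSPOSE_TRITON"] = "1"  # read at forward time by the branch above

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Four independent GEMMs share one module; m_splits partitions the batch rows.
layer = te.GroupedLinear(4, 128, 256, params_dtype=torch.bfloat16, device="cuda")
inp = torch.randn(4 * 32, 128, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()):
    out = layer(inp, [32, 32, 32, 32])  # each chunk is quantized one tensor at a time

print(out.shape)  # torch.Size([128, 256])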

transformer_engine/pytorch/tensor/float8_tensor.py

Lines changed: 6 additions & 1 deletion
@@ -247,7 +247,12 @@ def update_quantized(
             src = src.contiguous()
 
         # Launch cast kernel
-        tex.quantize(src, self, dst, noop_flag)
+        if IS_HIP_EXTENSION:
+            use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) )
+            quantize_func = te_quantize_triton if use_cast_transpose_triton else tex.quantize
+            quantize_func(src, self, dst, noop_flag)
+        else:
+            tex.quantize(src, self, dst, noop_flag)
 
         # Update FP8 dtype
         dst._fp8_dtype = self.dtype
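
Taken on its own, the dispatch above means a single environment variable decides whether update_quantized() goes through te_quantize_triton() or tex.quantize() on ROCm. A small sketch of driving it directly, reusing the make_empty()/update_quantized() pattern from the grouped_linear.py hunk above; which quantizer class owns this particular update_quantized() is not visible from the hunk alone, so the sketch assumes the current-scaling case that this commit wires up:

import os
os.environ.setdefault("NVTE_USE_CAST_TRANSPOSE_TRITON", "1")  # "0" keeps the tex.quantize path

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer

quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
src = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)

# Allocate the destination Float8Tensor, then cast into it in place.
dst = quantizer.make_empty(src.shape, dtype=src.dtype, device=src.device, requires_grad=False)
quantizer.update_quantized(src, dst, noop_flag=None)
print(dst._scale_inv)  # per-tensor scale_inv derived from the current amax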

transformer_engine/pytorch/triton_kernels/cast.py

Lines changed: 8 additions & 1 deletion
@@ -96,6 +96,10 @@ def te_quantize_triton(
         cast_out = out._data
         trans_out = out._transpose
         scale_inv_out = out._scale_inv
+
+        from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
+        is_current_scaling = isinstance(quantizer, Float8CurrentScalingQuantizer)
+
         te_cast_transpose_noop_triton(
             input_tensor,
             noop_flag,

@@ -104,7 +108,10 @@
             trans_out=trans_out,
             amax_out=amax_out,
             scale_inv_out=scale_inv_out,
-            otype=otype
+            otype=otype,
+            current_scaling=is_current_scaling,
+            eps = getattr(quantizer, "amax_epsilon", 0.0),
+            force_pow_2_scales = getattr(quantizer, "force_pow_2_scales", False),
         )
 
     else:

transformer_engine/pytorch/triton_kernels/cast_transpose.py

Lines changed: 150 additions & 2 deletions
@@ -16,6 +16,80 @@
 #### cast_transpose
 ##########################################
 
+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'GROUP_M': 1}, num_warps=4),
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'GROUP_M': 8}, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=8),
+    ],
+    key=['M', 'N'],
+)
+@triton.jit
+def _amax_reduce_triton(
+    A,
+    stride_am, stride_an,
+    M, N,
+    amax_ptr,  # float32[1], initialize to -inf on host
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    GROUP_M: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    grid_m = (M + BLOCK_M - 1) // BLOCK_M
+    grid_n = (N + BLOCK_N - 1) // BLOCK_N
+
+    width = GROUP_M * grid_n
+    group_id = pid // width
+    group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M)
+    pid_m = group_id * GROUP_M + (pid % group_size)
+    pid_n = (pid % width) // group_size
+
+    rm = pid_m.to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n.to(tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    A_ptrs = A + rm[:, None] * stride_am + rn[None, :] * stride_an
+    mask = (rm < M)[:, None] & (rn < N)[None, :]
+
+    a = tl.load(A_ptrs, mask=mask, other=0).to(tl.float32)
+    tile_amax = tl.max(tl.abs(a))
+    # accumulate tile-wise max into global amax
+    tl.atomic_max(amax_ptr, tile_amax, sem='relaxed')
+
+
+@triton.jit
+def _compute_scale_from_amax_triton(
+    amax_ptr,
+    scale_ptr,
+    inv_ptr,
+    max_fp8,
+    epsilon,
+    value_for_inf,
+    FORCE_POW_2_SCALES: tl.constexpr,
+):
+    # This implementation mimics transformer_engine::compute_scale_from_amax()
+
+    a = tl.load(amax_ptr).to(tl.float32)
+
+    # amax < epsilon -> epsilon (NaNs pass through)
+    a = tl.where(a < epsilon, epsilon, a)
+
+    # bad amax (NaN, inf, 0.0) -> scale = 1.0
+    bad = (a != a) | (tl.abs(a) == float('inf')) | (a == 0.0)
+
+    if bad:
+        s = tl.full((), 1.0, tl.float32)
+    else:
+        s = max_fp8 / a
+        # inf -> scale = value_for_inf
+        s = tl.where(tl.abs(a) == float('inf'), value_for_inf, s)
+        if FORCE_POW_2_SCALES:
+            s = tl.math.exp2(tl.floor(tl.log2(s)))
+
+    tl.store(scale_ptr, s)
+    tl.store(inv_ptr, 1.0 / s)
+
+
 @triton.autotune(
     configs=[
         triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'GROUP_M': 1}, num_warps=4),

@@ -69,6 +143,52 @@ def _cast_transpose_triton(A, noop_ptr, C, T, stride_am, stride_an, stride_bn, s
     scale_inv_out = tl.fdiv(1.0, scale)
     tl.store(scale_inv_ptr, scale_inv_out)
 
+
+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'GROUP_M': 1}, num_warps=4),
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'GROUP_M': 8}, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=8),
+    ],
+    key=['M', 'N']
+)
+@triton.jit
+def _cast_transpose_triton_current_scaling(A, C, T, stride_am, stride_an, stride_bn, stride_bm, M, N, scale_ptr, max_fp8: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, GROUP_M: tl.constexpr):
+    # Similar (but slightly optimized) version of the delayed scaling kernel
+    # implemented in _cast_transpose_triton().
+    pid = tl.program_id(0)
+    scale = tl.load(scale_ptr)
+
+    grid_m = (M + BLOCK_M - 1) // BLOCK_M
+    grid_n = (N + BLOCK_N - 1) // BLOCK_N
+
+    width = GROUP_M * grid_n
+    group_id = pid // width
+    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+    pid_m = group_id * GROUP_M + (pid % group_size)
+    pid_n = (pid % width) // group_size
+
+    rm = pid_m.to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n.to(tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+    A = A + rm[:, None] * stride_am + rn[None, :] * stride_an
+    mask = (rm < M)[:, None] & (rn < N)[None, :]
+    a = tl.load(A, mask=mask)
+    a = a.to(tl.float32)
+
+    scaled_a = a * scale
+    scaled_a = tl.clamp(scaled_a, -max_fp8, max_fp8)
+    fp8_a = scaled_a.to(C.type.element_ty)
+    C = C + rm[:, None] * stride_am + rn[None, :] * stride_an
+    tl.store(C, fp8_a, mask=mask)
+
+    # rematerialize to save registers
+    rm = pid_m.to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n.to(tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+    T = T + rm[:, None] * stride_bm + rn[None, :] * stride_bn
+    mask = (rm < M)[:, None] & (rn < N)[None, :]
+    tl.store(T, fp8_a, mask=mask)
+
+
 FP32_EXPONENT_BIAS = tl.constexpr(127)
 FP32_MANTISSA_BITS = tl.constexpr(23)
 @triton.jit

@@ -232,7 +352,7 @@ def _dequantize_mxfp8_triton(
 
 # Reshapes input of any given shape to 2D for processing,
 # then uses the Triton kernel to perform casting and transposition efficiently.
-def te_cast_transpose_noop_triton(input, noop_flag, input_scale, cast_out, trans_out, amax_out, scale_inv_out, otype):
+def te_cast_transpose_noop_triton(input, noop_flag, input_scale, cast_out, trans_out, amax_out, scale_inv_out, otype, current_scaling, eps, force_pow_2_scales):
 
     row_length = input.shape[-1] if len(input.shape) > 0 else 1
     num_rows = input.numel() // row_length

@@ -254,7 +374,35 @@ def te_cast_transpose_noop_triton(input, noop_flag, input_scale, cast_out, trans
         use_noop = False
 
     grid = lambda META: (triton.cdiv(num_rows, META['BLOCK_M']) * triton.cdiv(row_length, META['BLOCK_N']),)
-    _cast_transpose_triton[grid](input_2d_view, noop_flag, triton.reinterpret(cast_out_2d_view, tl_dtype), triton.reinterpret(trans_out_2d_view, tl_dtype), input_stride_M, input_stride_N, trans_out_stride_M, trans_out_stride_N, num_rows, row_length, input_scale, amax_out, scale_inv_out, get_fp8_max(otype), use_noop)
+
+    if current_scaling:
+        # Current scaling:
+        # 1) global amax reduction
+        # 2) compute current scale
+        # 3) cast+transpose with that current scale (otherwise same as delayed)
+
+        # global amax
+        amax_out.fill_(-float("inf"))
+        _amax_reduce_triton[grid](
+            input_2d_view,
+            input_stride_M, input_stride_N,
+            num_rows, row_length,
+            amax_out,
+        )
+
+        # Compute scale
+        fp8_max = get_fp8_max(otype)
+
+        _compute_scale_from_amax_triton[(1,)](
+            amax_out, input_scale, scale_inv_out,
+            fp8_max, eps, torch.finfo(torch.float32).max,
+            FORCE_POW_2_SCALES=force_pow_2_scales,
+        )
+
+        _cast_transpose_triton_current_scaling[grid](input_2d_view, triton.reinterpret(cast_out_2d_view, tl_dtype), triton.reinterpret(trans_out_2d_view, tl_dtype), input_stride_M, input_stride_N, trans_out_stride_M, trans_out_stride_N, num_rows, row_length, input_scale, get_fp8_max(otype))
+    else:
+        # Delayed scaling
+        _cast_transpose_triton[grid](input_2d_view, noop_flag, triton.reinterpret(cast_out_2d_view, tl_dtype), triton.reinterpret(trans_out_2d_view, tl_dtype), input_stride_M, input_stride_N, trans_out_stride_M, trans_out_stride_N, num_rows, row_length, input_scale, amax_out, scale_inv_out, get_fp8_max(otype), use_noop)
 
 def te_cast_transpose_mxfp8_triton(input, out, noop_flag=None):
     row_length = input.shape[-1] if len(input.shape) > 0 else 1
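
The comments in te_cast_transpose_noop_triton() describe the current-scaling path as a global amax reduction, a scale computation, and a cast+transpose with that scale. A minimal PyTorch reference of the same flow, handy for eyeballing the Triton kernels on small inputs (the function name is illustrative, the final conversion to an FP8 dtype is omitted, and the clamp stands in for saturation):

import torch

def current_scaling_cast_transpose_ref(x, max_fp8, eps=0.0, force_pow_2_scales=False):
    # 1) global amax reduction over the whole tensor
    amax = x.abs().amax().float()
    amax = torch.clamp(amax, min=eps)

    # 2) scale from amax (degenerate amax falls back to scale = 1.0)
    if torch.isnan(amax) or torch.isinf(amax) or amax == 0:
        scale = torch.tensor(1.0, device=x.device)
    else:
        scale = max_fp8 / amax
        if force_pow_2_scales:
            scale = torch.exp2(torch.floor(torch.log2(scale)))

    # 3) scaled, saturated cast, plus the transpose of the same values
    y = torch.clamp(x.float() * scale, -max_fp8, max_fp8)
    return y, y.t().contiguous(), scale, 1.0 / scale

# 240.0 is the E4M3 fnuz maximum used on ROCm (448.0 for the non-fnuz variant).
x = torch.randn(64, 32)
cast_ref, trans_ref, scale, scale_inv = current_scaling_cast_transpose_ref(x, max_fp8=240.0)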
