Merged
Commits (49)
c15d93b
Current scaling: two-stage amax kernel
matthiasdiener Nov 12, 2025
51fab36
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 13, 2025
ae35e4c
bugfix graph capture
matthiasdiener Nov 13, 2025
77a68a7
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 17, 2025
c0d8e73
outline workspace allocation
matthiasdiener Nov 17, 2025
6c3507d
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 18, 2025
3c9de07
Proper allocation of workspace
matthiasdiener Nov 18, 2025
91249cc
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 19, 2025
be0e0c8
add a test to compare the accuracy of both amax implementations
matthiasdiener Nov 19, 2025
bce34da
add possibility to force using previous (atomic) kernel
matthiasdiener Nov 19, 2025
8c388cc
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 19, 2025
73c8d4e
2-stage Triton amax
matthiasdiener Nov 20, 2025
6388604
add copyrights
matthiasdiener Nov 20, 2025
9e6586f
don't add extra template to kernel
matthiasdiener Nov 20, 2025
18292bf
make amax_kernel_threads usable in pytorch
matthiasdiener Nov 21, 2025
a389455
update remaining calls to nvte_compute_amax
matthiasdiener Nov 21, 2025
d87ab8a
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 24, 2025
7d9ee16
Merge branch 'speedup-amax-kernel' into speedup-amax-triton
matthiasdiener Nov 24, 2025
fd5dead
additional copyrights
matthiasdiener Nov 24, 2025
16d3bf9
avoid workspace allocations if NVTE_USE_ATOMIC_AMAX is set
matthiasdiener Nov 24, 2025
50b34aa
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 25, 2025
ef532b1
remove use_block_amax parameter, more cleanups
matthiasdiener Nov 25, 2025
f933ef3
Factor workspace allocation into function
matthiasdiener Nov 25, 2025
7d4054e
expand test slightly
matthiasdiener Nov 25, 2025
63cff98
Revert "expand test slightly"
Nov 25, 2025
c7d44a7
guard by HIP macro, address review comments
matthiasdiener Nov 26, 2025
f92b926
bugfix workspace.data.dptr
matthiasdiener Nov 26, 2025
eba552e
various cleanups
matthiasdiener Nov 26, 2025
0d6a177
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 26, 2025
19901a0
Merge branch 'speedup-amax-kernel' into speedup-amax-triton
matthiasdiener Nov 26, 2025
8eda427
simplify types in allocate_amax_workspace
matthiasdiener Nov 26, 2025
be6496b
Merge branch 'speedup-amax-kernel' into speedup-amax-triton
matthiasdiener Nov 26, 2025
ed1a54b
Fixes
matthiasdiener Nov 26, 2025
c8d5bb4
add support for NVTE_USE_ATOMIC_AMAX
matthiasdiener Nov 26, 2025
5a9086a
Fuse amax_reduce + compute_scale kernels
matthiasdiener Nov 26, 2025
6990928
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Dec 1, 2025
9ee618f
fix indentation
matthiasdiener Dec 1, 2025
853bb77
Merge branch 'speedup-amax-kernel' into speedup-amax-triton
matthiasdiener Dec 1, 2025
cf402b1
undo non-triton changes
matthiasdiener Dec 1, 2025
2c9cc65
[ROCm] use at::empty(0, fp32) as amax workspace for makeTransformerEn…
wangye805 Dec 7, 2025
e41e1d4
Merge branch 'dev' into speedup-amax-triton
matthiasdiener Dec 8, 2025
862ec74
Merge branch 'yewang12/amax-workspace-fix' into speedup-amax-triton
matthiasdiener Dec 8, 2025
35f2d38
add more tests
matthiasdiener Dec 8, 2025
1cbb68f
Merge branch 'dev' into speedup-amax-triton
matthiasdiener Dec 11, 2025
d7259d1
add more tests and re-add comment
matthiasdiener Dec 15, 2025
42c7ac3
Merge branch 'dev' into speedup-amax-triton
matthiasdiener Dec 15, 2025
ef31ef7
Merge branch 'dev' into speedup-amax-triton
matthiasdiener Dec 18, 2025
25c91e8
restore FP8 current scaling support
matthiasdiener Dec 18, 2025
188b7ca
add test comparing atomic amax and 2-stage
matthiasdiener Dec 18, 2025
5 changes: 5 additions & 0 deletions ci/pytorch.sh
@@ -77,6 +77,11 @@ run_test_config(){
NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa_lbl "triton" 1 test_float8_current_scaling_exact.py
NVTE_USE_ATOMIC_AMAX=1 run_default_fa 3 test_numerics.py
NVTE_USE_ATOMIC_AMAX=1 run_default_fa 3 test_fusible_ops.py
NVTE_USE_ATOMIC_AMAX=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa 3 test_numerics.py
Collaborator:

The triton path is not enabled by default, so I think you will need to test with both NVTE_USE_ATOMIC_AMAX=1 and NVTE_USE_ATOMIC_AMAX=0 when NVTE_USE_CAST_TRANSPOSE_TRITON is 1.

Also, I am not sure about the runtime cost of adding two new pytests in level 3.

Contributor Author:

> The triton path is not enabled by default, so I think you will need to test with both NVTE_USE_ATOMIC_AMAX=1 and NVTE_USE_ATOMIC_AMAX=0 when NVTE_USE_CAST_TRANSPOSE_TRITON is 1.

I added both cases in d7259d1.

> Also, I am not sure about the runtime cost of adding two new pytests in level 3.

test_numerics.py takes about 5 min and test_fusible_ops.py takes about 1 min (on gfx942), times two since we run them with both NVTE_USE_ATOMIC_AMAX=0 and =1. Perhaps adding just the test in 188b7ca is enough?

Collaborator:

5 min sounds okay for level 3. @ipanfilo, what do you think?

Contributor Author:

After discussing with @wenchenvincent, we concluded that it is worth keeping the extra tests around.

NVTE_USE_ATOMIC_AMAX=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa 3 test_fusible_ops.py
NVTE_USE_ATOMIC_AMAX=0 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa 3 test_numerics.py
NVTE_USE_ATOMIC_AMAX=0 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa 3 test_fusible_ops.py
NVTE_USE_ATOMIC_AMAX=1 run_default_fa 3 triton_kernels/test_cast.py
}

run_test_config_mgpu(){
42 changes: 42 additions & 0 deletions tests/pytorch/triton_kernels/test_cast.py
@@ -167,3 +167,45 @@ def test_compute_scale_from_amax(amax_val, force_pow_2_scales, epsilon, fp8_dtyp

torch.testing.assert_close(scale_triton, scale_ref[0], rtol=0.0, atol=0.0)
torch.testing.assert_close(scale_inv_triton, scale_inv_ref[0], rtol=0.0, atol=0.0)


@pytest.mark.parametrize("shape", ((1, 1), (7, 13), (256, 257), (1024, 1024), (2048, 4097)))
@pytest.mark.parametrize("in_dtype", (torch.float16, torch.bfloat16))
@pytest.mark.parametrize("out_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_amax_atomic_vs_two_stage(shape, in_dtype, out_dtype):
import os
device = "cuda"
input_tensor = fill_uniform(shape, dtype=in_dtype)

quantizer_atomic = Float8CurrentScalingQuantizer(fp8_dtype=out_dtype, device=device)
quantizer_2stage = Float8CurrentScalingQuantizer(fp8_dtype=out_dtype, device=device)

env_key = "NVTE_USE_ATOMIC_AMAX"
old_env_val = os.environ.get(env_key)

try:
# atomic amax
os.environ[env_key] = "1"

with fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()):
out_atomic = te_quantize_triton(input_tensor, quantizer=quantizer_atomic)

# 2-stage amax
os.environ[env_key] = "0"

with fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()):
out_2stage = te_quantize_triton(input_tensor, quantizer=quantizer_2stage)

te_compare_results(
out_atomic._get_quantizer().amax,
out_2stage._get_quantizer().amax,
atol=0.0, rtol=0.0,
msg='AMAX results do not match!',
use_torch_semantics=True
)
finally:
# Restore environment
if old_env_val is None:
os.environ.pop(env_key, None)
else:
os.environ[env_key] = old_env_val
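
A side note on the environment handling in this test: the same save/restore can be expressed with pytest's monkeypatch fixture, which undoes the change automatically after the test. This is only an illustrative sketch, not part of the PR, and it reuses the helpers already imported by this test module (fill_uniform, te_quantize_triton, Float8CurrentScalingQuantizer, fp8_autocast, recipe, tex, torch).

import pytest

@pytest.mark.parametrize("use_atomic_amax", ["0", "1"])
def test_amax_env_toggle_sketch(monkeypatch, use_atomic_amax):
    # monkeypatch.setenv restores NVTE_USE_ATOMIC_AMAX afterwards, replacing the
    # explicit try/finally bookkeeping used above.
    monkeypatch.setenv("NVTE_USE_ATOMIC_AMAX", use_atomic_amax)
    input_tensor = fill_uniform((256, 257), dtype=torch.bfloat16)
    quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
    with fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()):
        te_quantize_triton(input_tensor, quantizer=quantizer)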
3 changes: 0 additions & 3 deletions transformer_engine/pytorch/triton_kernels/cast.py
@@ -56,9 +56,6 @@ def te_quantize_triton(
Quantizes the input tensor using a specified quantizer,
with an option to utilize Triton-based `cast_transpose` for performance.
"""
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
Contributor Author:

@wangye805 Do you remember why current scaling was disabled here (in #374)?

Collaborator:

I recall I moved this line to

    from ..tensor.float8_tensor import Float8CurrentScalingQuantizer

instead of disabling the current scaling quantizer entirely, in order to resolve a circular inclusion issue.
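
For context on that circular-inclusion issue: deferring the import into the function body is the usual way to break an import cycle in Python, because the name is only resolved at call time, once both modules are fully initialized. The sketch below uses hypothetical module names (a.py, b.py), not the real transformer_engine layout, and is only an illustration of the pattern.

# a.py
def te_quantize(tensor, quantizer):
    # Deferred import: resolved when the function is called, after b.py has
    # finished importing, so the a.py <-> b.py dependency no longer fails with
    # a partially-initialized module.
    from b import CurrentScalingQuantizer
    if isinstance(quantizer, CurrentScalingQuantizer):
        return quantizer.apply(tensor)
    return tensor

# b.py
import a  # only one edge of the cycle may be eager at import time; this one is safe

class CurrentScalingQuantizer:
    def apply(self, tensor):
        return a.te_quantize(tensor, None)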

if isinstance(quantizer, Float8CurrentScalingQuantizer):
return tex.quantize(tensor, quantizer, output, noop_flag)
input_tensor = tensor.contiguous()
fake_tensor_type = input_tensor.dtype
if not fake_tensor_type.is_floating_point:
164 changes: 146 additions & 18 deletions transformer_engine/pytorch/triton_kernels/cast_transpose.py
@@ -12,6 +12,8 @@
te_dtype_to_torch_dtype,
get_fp8_max,
)
import os

##########################################
#### cast_transpose
##########################################
@@ -189,6 +191,101 @@ def _cast_transpose_triton_current_scaling(A, C, T, stride_am, stride_an, stride
tl.store(T, fp8_a, mask=mask)


AMAX_STAGE1_CONFIGS = [
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'GROUP_M': 1}, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'GROUP_M': 8}, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=8),
]

@triton.autotune(
configs=AMAX_STAGE1_CONFIGS,
key=['M', 'N'],
)
@triton.jit
def _amax_reduce_triton_stage1(
A,
stride_am, stride_an,
M, N,
block_amax, # float32[workspace_size]
num_blocks, # int32[1]
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
GROUP_M: tl.constexpr,
):
pid = tl.program_id(0)

grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N

width = GROUP_M * grid_n
group_id = pid // width
group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // group_size

rm = pid_m.to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n.to(tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)

A_ptrs = A + rm[:, None] * stride_am + rn[None, :] * stride_an
mask = (rm < M)[:, None] & (rn < N)[None, :]

a = tl.load(A_ptrs, mask=mask, other=0).to(tl.float32)
tile_amax = tl.max(tl.abs(a))

# Store per-program amax in workspace
tl.store(block_amax + pid, tile_amax)

if pid == 0:
tl.store(num_blocks, tl.num_programs(0))

@triton.jit
def _amax_reduce_and_compute_scale_triton(
block_amax, # float32[num_blocks]
num_blocks, # int32[1]
amax_ptr, # float32[1]
scale_ptr, # float32[1]
inv_ptr, # float32[1]
max_fp8, # scalar (float32)
epsilon, # scalar (float32)
value_for_inf, # scalar (float32)
FORCE_POW_2_SCALES: tl.constexpr,
BLOCKSIZE: tl.constexpr,
):
# Reduce per-block amaxes
a = tl.full((), -float('inf'), tl.float32)
offset = 0
num_blocks = tl.load(num_blocks)

while offset < num_blocks:
idx = offset + tl.arange(0, BLOCKSIZE)
mask = idx < num_blocks
vals = tl.load(block_amax + idx, mask=mask, other=-float('inf'))
a = tl.maximum(a, tl.max(vals))
offset += BLOCKSIZE

tl.store(amax_ptr, a)

# Compute scale + inv_scale from amax

# amax < epsilon -> epsilon (NaNs pass through)
a = tl.where(a < epsilon, epsilon, a)

# bad amax (NaN, inf, 0.0) -> scale = 1.0
bad = (a != a) | (tl.abs(a) == float('inf')) | (a == 0.0)

if bad:
s = tl.full((), 1.0, tl.float32)
else:
s = max_fp8 / a
# inf -> scale = value_for_inf
s = tl.where(tl.abs(a) == float('inf'), value_for_inf, s)
if FORCE_POW_2_SCALES:
s = tl.math.exp2(tl.floor(tl.log2(s)))

tl.store(scale_ptr, s)
tl.store(inv_ptr, 1.0 / s)
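
For reference, the two kernels above implement a plain blockwise abs-max followed by a final reduction and scale derivation. A loose PyTorch equivalent is sketched below; it is illustrative only and ignores the NaN/inf/zero guards and the optional power-of-two rounding handled in the stage-2 kernel.

import torch

def two_stage_amax_reference(x: torch.Tensor, max_fp8: float, chunk: int = 4096):
    flat = x.reshape(-1).abs().float()
    # Stage 1 analogue: one partial maximum per chunk, standing in for the
    # per-program tile maxima written to block_amax.
    pad = (-flat.numel()) % chunk
    if pad:
        flat = torch.nn.functional.pad(flat, (0, pad))
    block_amax = flat.reshape(-1, chunk).amax(dim=1)
    # Stage 2 analogue: reduce the partial maxima, then derive the current scale.
    amax = block_amax.amax()
    scale = max_fp8 / amax
    return amax, scale, 1.0 / scale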


FP32_EXPONENT_BIAS = tl.constexpr(127)
FP32_MANTISSA_BITS = tl.constexpr(23)
@triton.jit
@@ -376,28 +473,59 @@ def te_cast_transpose_noop_triton(input, noop_flag, input_scale, cast_out, trans
grid = lambda META: (triton.cdiv(num_rows, META['BLOCK_M']) * triton.cdiv(row_length, META['BLOCK_N']),)

if current_scaling:
# Current scaling:
# 1) global amax reduction
# 2) compute current scale
# 3) cast+transpose with that current scale (otherwise same as delayed)
# 1) global amax reduction
# 2) compute current scale
# 3) cast+transpose with that current scale (otherwise same as delayed)

# global amax
amax_out.fill_(-float("inf"))
_amax_reduce_triton[grid](
input_2d_view,
input_stride_M, input_stride_N,
num_rows, row_length,
amax_out,
)

# Compute scale
fp8_max = get_fp8_max(otype)

_compute_scale_from_amax_triton[(1,)](
amax_out, input_scale, scale_inv_out,
fp8_max, eps, torch.finfo(torch.float32).max,
FORCE_POW_2_SCALES=force_pow_2_scales,
)
nvte_use_atomic_amax = bool( int(os.environ.get('NVTE_USE_ATOMIC_AMAX', '0')) )

if nvte_use_atomic_amax:
# Compute global amax
_amax_reduce_triton[grid](
input_2d_view,
input_stride_M, input_stride_N,
num_rows, row_length,
amax_out,
)

# Compute scale
_compute_scale_from_amax_triton[(1,)](
amax_out, input_scale, scale_inv_out,
fp8_max, eps, torch.finfo(torch.float32).max,
FORCE_POW_2_SCALES=force_pow_2_scales,
)
else:
# 2-stage amax
max_num_amax_stage1_programs = max(
triton.cdiv(num_rows, cfg.kwargs['BLOCK_M']) *
triton.cdiv(row_length, cfg.kwargs['BLOCK_N'])
for cfg in AMAX_STAGE1_CONFIGS
)

block_amax = torch.empty(max_num_amax_stage1_programs, device=input.device,
dtype=torch.float32)

num_blocks = torch.empty(1, device=input.device, dtype=torch.int32)

# Stage 1: per-program tile amax
_amax_reduce_triton_stage1[grid](
input_2d_view,
input_stride_M, input_stride_N,
num_rows, row_length,
block_amax, num_blocks,
)

# Stage 2: reduce per-program maxima into amax_out and compute scale
_amax_reduce_and_compute_scale_triton[(1,)](
block_amax, num_blocks,
amax_out, input_scale, scale_inv_out,
fp8_max, eps, torch.finfo(torch.float32).max,
FORCE_POW_2_SCALES=force_pow_2_scales,
BLOCKSIZE=512,
)

_cast_transpose_triton_current_scaling[grid](input_2d_view, triton.reinterpret(cast_out_2d_view, tl_dtype), triton.reinterpret(trans_out_2d_view, tl_dtype), input_stride_M, input_stride_N, trans_out_stride_M, trans_out_stride_N, num_rows, row_length, input_scale, get_fp8_max(otype))
else:
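
A quick sizing check for the stage-1 workspace allocated above (illustrative arithmetic, not part of the PR): for the largest test shape (2048, 4097), the 64x64 block configs dominate the worst case used to size block_amax.

from math import ceil

M, N = 2048, 4097
block_shapes = [(64, 64), (64, 64), (128, 128)]  # (BLOCK_M, BLOCK_N) of AMAX_STAGE1_CONFIGS
programs = [ceil(M / bm) * ceil(N / bn) for bm, bn in block_shapes]
print(max(programs))  # 2080 -> block_amax holds 2080 float32 values, roughly 8 KiB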