Skip to content

Commit c7ff835

Browse files
Author: pytorchbot
Commit message: "2026-03-21 nightly release (0e17236)"
Commit: c7ff835 (1 parent: 6438e23)

File tree

8 files changed

+50
-8
lines changed

8 files changed

+50
-8
lines changed

.github/scripts/nova_dir.bash

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,15 @@
99
MSLK_DIR="/__w/MSLK/MSLK"
1010
export MSLK_REPO="${MSLK_DIR}/${REPOSITORY}"
1111

12-
export BUILD_FROM_NOVA=0
12+
################################################################################
13+
# Because we have a custom setup.py with extra flags, we have to do clean /
14+
# build_wheel during the pre-script stage, since we have no control over the
15+
# invocation of setup.py in the Nova build stage.
16+
#
17+
# As such, set the flag here so that setup.py will skip these steps in Nova
18+
# workflow in the build stage.
19+
################################################################################
20+
export BUILD_FROM_NOVA=1
1321

1422
# Disable HIP FMHA build in the manywheel CI (the runner is too small)
1523
export MSLK_BUILD_HIP_FMHA=0

.github/scripts/nova_prescript.bash

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,14 @@ if [[ ${CHANNEL} == "" ]]; then
136136
export CHANNEL="nightly"
137137
fi
138138

139+
################################################################################
140+
# Build the wheel
141+
#
142+
# The build is performed in the pre-script stage of the build workflow since we
143+
# have no control over the invocation of setup.py in the actual build stage.
144+
################################################################################
145+
146+
build_mslk_package "${BUILD_ENV_NAME}" "${CHANNEL}" "${mslk_build_target}/${mslk_build_variant}"
139147
end_time=$(date +%s)
140148
runtime=$((end_time-start_time))
141-
start_time=${end_time}
142-
echo "[NOVA] Time taken to prepare to build the package: ${runtime} seconds / $(display_time ${runtime})"
149+
echo "[NOVA] Time taken to build the package: ${runtime} seconds / $(display_time ${runtime})"

mslk/attention/fmha/utils/op_common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
import torch
1111

12+
from . import cpp_lib as _cpp_lib # noqa: F401 -- loads _C_hip native extension
13+
1214

1315
def get_operator(library: str, name: str):
1416
def no_such_operator(*args, **kwargs):

mslk/quantize/triton/fp4_quantize.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5370,6 +5370,9 @@ def triton_quantize_nvfp4(
53705370
# Pass a dummy pointer; the kernel won't load from it.
53715371
global_scale = x.new_empty(())
53725372

5373+
# Use int64 indexing when pointer offsets can exceed INT32_MAX
5374+
use_int64_indexing = M * N > 2**31 - 1
5375+
53735376
triton_quantize_nvfp4_kernel[grid](
53745377
x,
53755378
global_scale,
@@ -5389,6 +5392,8 @@ def triton_quantize_nvfp4(
53895392
USE_PRECISE_MATH=use_precise_math,
53905393
# pyre-ignore[6]
53915394
USE_GLOBAL_SCALE=use_global_scale,
5395+
# pyre-ignore[6]
5396+
USE_INT64_INDEXING=use_int64_indexing,
53925397
)
53935398

53945399
# reshape back to original shape
@@ -5413,6 +5418,7 @@ def triton_quantize_nvfp4_kernel(
54135418
USE_E8M0_SCALE: tl.constexpr,
54145419
USE_PRECISE_MATH: tl.constexpr,
54155420
USE_GLOBAL_SCALE: tl.constexpr,
5421+
USE_INT64_INDEXING: tl.constexpr,
54165422
):
54175423
E4M3_EPS = 1.5258789e-05
54185424
FP8_E4M3_MAX = 448.0
@@ -5444,6 +5450,10 @@ def triton_quantize_nvfp4_kernel(
54445450

54455451
offs_m = pid_m * M_PER_BLOCK + tl.arange(0, M_PER_BLOCK)[:, None]
54465452
offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
5453+
if USE_INT64_INDEXING:
5454+
offs_m = offs_m.to(tl.int64)
5455+
offs_n = offs_n.to(tl.int64)
5456+
54475457
if USE_MASK:
54485458
mask = (offs_m < M) & (offs_n < N)
54495459
other = 0.0
@@ -5456,9 +5466,8 @@ def triton_quantize_nvfp4_kernel(
54565466
else:
54575467
global_scale = 1.0
54585468

5459-
x = tl.load(
5460-
x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
5461-
) # [M_PER_BLOCK, 64]
5469+
load_offsets = offs_m * stride_xm + offs_n * stride_xn
5470+
x = tl.load(x_ptr + load_offsets, mask=mask, other=other) # [M_PER_BLOCK, 64]
54625471
x_blocks = x.to(tl.float32).reshape(M_PER_BLOCK, 4, 16) # [M_PER_BLOCK, 4, 16]
54635472

54645473
# Block-wise max
@@ -5519,7 +5528,13 @@ def triton_quantize_nvfp4_kernel(
55195528
mask = (offs_m < M) & (offs_n < N // 2)
55205529
else:
55215530
mask = None
5522-
tl.store(q_ptr + offs_m * (N // 2) + offs_n, x_fp4x2, mask=mask)
5531+
5532+
if USE_INT64_INDEXING:
5533+
offs_m = offs_m.to(tl.int64)
5534+
offs_n = offs_n.to(tl.int64)
5535+
5536+
store_offsets = offs_m * (N // 2) + offs_n
5537+
tl.store(q_ptr + store_offsets, x_fp4x2, mask=mask)
55235538

55245539

55255540
@triton.jit

test/attention/fmha/test_fmha_merge_attentions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from .utils import (
2424
assert_allclose,
25+
cuda_only,
2526
disable_on_rocm,
2627
sm80_or_better_only,
2728
UNSUPPORTED_OP_PASSES,
@@ -479,6 +480,7 @@ def test_merge_attentions_sharedinput(
479480
)
480481

481482

483+
@cuda_only
482484
@sm80_or_better_only
483485
@pytest.mark.parametrize("bmghk", (False, True))
484486
def test_merge_attentions_against_ref(bmghk: bool):
@@ -685,6 +687,7 @@ def test_merge_training_zilch():
685687

686688

687689
@sm80_or_better_only
690+
@cuda_only
688691
def test_merge_training_undilate():
689692
torch.manual_seed(1)
690693

test/attention/fmha/test_fmha_split_blocks_fairinternal.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def test_split_blocks_for_decoding():
4646
assert (chunked_bias.k_seqinfo.seqstart >= attn_bias.k_seqinfo.seqstart).all()
4747

4848

49+
@cuda_only
4950
def test_split_blocks_for_decoding_with_paged():
5051
torch.manual_seed(0)
5152
max_len_kv = 2048

test/attention/fmha/test_mem_eff_attention.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def test_dropout_ck(q_len, kv_len, batch_size, k_len, p, seed, attn_bias):
313313
def test_dropout_backward_ck(q_len, kv_len, batch_size, k, p):
314314
op = fmha.ck.FwOp
315315
dtype = torch.float16
316-
if not op.is_available():
316+
if not fmha.ck.BwOp.is_available():
317317
if UNSUPPORTED_OP_PASSES:
318318
return
319319
pytest.skip()
@@ -614,6 +614,7 @@ def test_unsupported_stride_alignment(op: Type[fmha.AttentionFwOpBase]):
614614

615615

616616
@sm75_or_better_only
617+
@cuda_only
617618
def test_unsupported_dropout_combine_flash_cutlass() -> None:
618619
q = torch.empty(
619620
[1, 4, 1, 16], device="cuda", dtype=torch.float16, requires_grad=True
@@ -1893,6 +1894,10 @@ def test_memeff_compile(bias_t, create_bias_inside_compiled: bool, op) -> None:
18931894
if UNSUPPORTED_OP_PASSES:
18941895
return
18951896
pytest.skip("Op is not available")
1897+
if (not not torch.version.hip) and not fmha.ck.BwOp.is_available():
1898+
if UNSUPPORTED_OP_PASSES:
1899+
return
1900+
pytest.skip("Op is not available")
18961901
torch._dynamo.reset_code_caches() # avoids hitting recompilation limit
18971902
B, M, H, K = 1, 256, 2, 64
18981903
q, k, v, bias = create_tensors(

test/quantize/triton/fp4_quantize_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ def test_fake_quantize_nvfp4_per_tensor(
284284
(4000, 4096), # large matrix with m padding
285285
(4096, 4080), # large square matrix with n padding
286286
(4000, 4080), # large square matrix with m and n padding
287+
(147456, 15360), # > int32 addressing
287288
],
288289
)
289290
@pytest.mark.parametrize("use_global_scale", [True, False])

0 commit comments

Comments (0)