@@ -5370,6 +5370,9 @@ def triton_quantize_nvfp4(
53705370 # Pass a dummy pointer; the kernel won't load from it.
53715371 global_scale = x .new_empty (())
53725372
5373+ # Use int64 indexing when pointer offsets can exceed INT32_MAX
5374+ use_int64_indexing = M * N > 2 ** 31 - 1
5375+
53735376 triton_quantize_nvfp4_kernel [grid ](
53745377 x ,
53755378 global_scale ,
@@ -5389,6 +5392,8 @@ def triton_quantize_nvfp4(
53895392 USE_PRECISE_MATH = use_precise_math ,
53905393 # pyre-ignore[6]
53915394 USE_GLOBAL_SCALE = use_global_scale ,
5395+ # pyre-ignore[6]
5396+ USE_INT64_INDEXING = use_int64_indexing ,
53925397 )
53935398
53945399 # reshape back to original shape
@@ -5413,6 +5418,7 @@ def triton_quantize_nvfp4_kernel(
54135418 USE_E8M0_SCALE : tl .constexpr ,
54145419 USE_PRECISE_MATH : tl .constexpr ,
54155420 USE_GLOBAL_SCALE : tl .constexpr ,
5421+ USE_INT64_INDEXING : tl .constexpr ,
54165422):
54175423 E4M3_EPS = 1.5258789e-05
54185424 FP8_E4M3_MAX = 448.0
@@ -5444,6 +5450,10 @@ def triton_quantize_nvfp4_kernel(
54445450
54455451 offs_m = pid_m * M_PER_BLOCK + tl .arange (0 , M_PER_BLOCK )[:, None ]
54465452 offs_n = pid_n * 64 + tl .arange (0 , 64 )[None , :]
5453+ if USE_INT64_INDEXING :
5454+ offs_m = offs_m .to (tl .int64 )
5455+ offs_n = offs_n .to (tl .int64 )
5456+
54475457 if USE_MASK :
54485458 mask = (offs_m < M ) & (offs_n < N )
54495459 other = 0.0
@@ -5456,9 +5466,8 @@ def triton_quantize_nvfp4_kernel(
54565466 else :
54575467 global_scale = 1.0
54585468
5459- x = tl .load (
5460- x_ptr + offs_m * stride_xm + offs_n * stride_xn , mask = mask , other = other
5461- ) # [M_PER_BLOCK, 64]
5469+ load_offsets = offs_m * stride_xm + offs_n * stride_xn
5470+ x = tl .load (x_ptr + load_offsets , mask = mask , other = other ) # [M_PER_BLOCK, 64]
54625471 x_blocks = x .to (tl .float32 ).reshape (M_PER_BLOCK , 4 , 16 ) # [M_PER_BLOCK, 4, 16]
54635472
54645473 # Block-wise max
@@ -5519,7 +5528,13 @@ def triton_quantize_nvfp4_kernel(
55195528 mask = (offs_m < M ) & (offs_n < N // 2 )
55205529 else :
55215530 mask = None
5522- tl .store (q_ptr + offs_m * (N // 2 ) + offs_n , x_fp4x2 , mask = mask )
5531+
5532+ if USE_INT64_INDEXING :
5533+ offs_m = offs_m .to (tl .int64 )
5534+ offs_n = offs_n .to (tl .int64 )
5535+
5536+ store_offsets = offs_m * (N // 2 ) + offs_n
5537+ tl .store (q_ptr + store_offsets , x_fp4x2 , mask = mask )
55235538
55245539
55255540@triton .jit
0 commit comments