Skip to content

Commit a53a36e

Browse files
nanz-nvvasunvidia
authored and committed
1. allocate stashing buffer based on avg token count if STASH_BUFFER_SIZE_FACTOR is positive.
2. fix int32 overflow in some triton kernels when token count is large. 3. fix a problem where restored activations might get deallocated prematurely.
1 parent e08e0b9 commit a53a36e

File tree

3 files changed

+117
-22
lines changed

3 files changed

+117
-22
lines changed

megatron/core/fusions/fused_bias_swiglu.py

Lines changed: 17 additions & 11 deletions
Original file line number · Diff line number · Diff line change
@@ -311,16 +311,17 @@ def _weighted_swiglu_fwd_kernel(
311311
# Strided access: each block handles tokens [pid, pid+num_blocks, ...]
312312
token_idx = pid
313313
while token_idx < num_tokens:
314+
token_idx_i64 = token_idx.to(tl.int64)
314315
# Load weight for this token
315-
weight = tl.load(weights_ptr + token_idx)
316+
weight = tl.load(weights_ptr + token_idx_i64)
316317

317318
# Process hidden dimension
318319
for h_offset in range(0, hidden_size, BLOCK_SIZE):
319320
h_mask = (h_offset + tl.arange(0, BLOCK_SIZE)) < hidden_size
320321

321322
# Load input chunks (gate and value)
322-
input_offset_1 = token_idx * (hidden_size * 2) + h_offset
323-
input_offset_2 = token_idx * (hidden_size * 2) + hidden_size + h_offset
323+
input_offset_1 = token_idx_i64 * (hidden_size * 2) + h_offset
324+
input_offset_2 = token_idx_i64 * (hidden_size * 2) + hidden_size + h_offset
324325

325326
y1 = tl.load(
326327
input_ptr + input_offset_1 + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0
@@ -341,7 +342,7 @@ def _weighted_swiglu_fwd_kernel(
341342
result = silu_y1 * y2_fp32 * weight_fp32
342343

343344
# Store output (cast back to original dtype)
344-
output_offset = token_idx * hidden_size + h_offset
345+
output_offset = token_idx_i64 * hidden_size + h_offset
345346
tl.store(
346347
output_ptr + output_offset + tl.arange(0, BLOCK_SIZE),
347348
result.to(y1.dtype),
@@ -376,8 +377,9 @@ def _weighted_swiglu_bwd_kernel(
376377
# Strided access
377378
token_idx = pid
378379
while token_idx < num_tokens:
380+
token_idx_i64 = token_idx.to(tl.int64)
379381
# Load weight for this token
380-
weight = tl.load(weights_ptr + token_idx)
382+
weight = tl.load(weights_ptr + token_idx_i64)
381383

382384
# Accumulator for weight gradient (fp32 for precision)
383385
weight_grad_acc = 0.0
@@ -387,14 +389,14 @@ def _weighted_swiglu_bwd_kernel(
387389
h_mask = (h_offset + tl.arange(0, BLOCK_SIZE)) < hidden_size
388390

389391
# Load grad_output
390-
grad_out_offset = token_idx * hidden_size + h_offset
392+
grad_out_offset = token_idx_i64 * hidden_size + h_offset
391393
grad_out = tl.load(
392394
grad_output_ptr + grad_out_offset + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0
393395
)
394396

395397
# Load input chunks
396-
input_offset_1 = token_idx * (hidden_size * 2) + h_offset
397-
input_offset_2 = token_idx * (hidden_size * 2) + hidden_size + h_offset
398+
input_offset_1 = token_idx_i64 * (hidden_size * 2) + h_offset
399+
input_offset_2 = token_idx_i64 * (hidden_size * 2) + hidden_size + h_offset
398400

399401
y1 = tl.load(
400402
input_ptr + input_offset_1 + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0
@@ -439,7 +441,7 @@ def _weighted_swiglu_bwd_kernel(
439441
weight_grad_acc += tl.sum(weight_grad_contribution)
440442

441443
# Store weight gradient after processing all chunks
442-
tl.store(grad_weights_ptr + token_idx, weight_grad_acc)
444+
tl.store(grad_weights_ptr + token_idx_i64, weight_grad_acc)
443445

444446
# Stride to next token
445447
token_idx += num_blocks
@@ -471,9 +473,13 @@ def weighted_swiglu_triton(input, weights, num_tokens_tensor):
471473
grid = (num_blocks,)
472474

473475
_weighted_swiglu_fwd_kernel[grid](
474-
input, weights, output, num_tokens_tensor, hidden_size=hidden_size, BLOCK_SIZE=BLOCK_SIZE
476+
input,
477+
weights,
478+
output,
479+
num_tokens_tensor,
480+
hidden_size=hidden_size,
481+
BLOCK_SIZE=BLOCK_SIZE,
475482
)
476-
477483
return output
478484

479485

megatron/core/transformer/moe/experts.py

Lines changed: 25 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -716,10 +716,18 @@ def forward(
716716
if self.config.moe_paged_stash:
717717
permuted_local_hidden_states = paged_stash_group_start(permuted_local_hidden_states)
718718
if self.moe_paged_stash_expert_fc1:
719+
max_num_tokens = permuted_local_hidden_states.shape[0]
720+
# Average/expected tokens is a pre-padding estimate used by paged stashing heuristics.
721+
# moe_expert_rank_capacity_factor is required when moe_paged_stash is enabled.
722+
cap_factor = self.config.moe_expert_rank_capacity_factor
723+
avg_num_tokens = (
724+
int(max_num_tokens // cap_factor) if cap_factor is not None and cap_factor > 0 else None
725+
)
719726
offload_context = get_paged_stash_context(
720727
name="expert_fc1",
721-
max_num_tokens=permuted_local_hidden_states.shape[0],
728+
max_num_tokens=max_num_tokens,
722729
num_tokens_tensor=tokens_per_expert.sum(),
730+
avg_num_tokens=avg_num_tokens,
723731
)
724732
else:
725733
offload_context = nullcontext()
@@ -809,10 +817,18 @@ def glu(x):
809817
else:
810818
with off_interface(self.offload_moe_act, fc1_output, "moe_act") as fc1_output:
811819
if self.moe_paged_stash_moe_act:
820+
max_num_tokens = fc1_output.shape[0]
821+
cap_factor = self.config.moe_expert_rank_capacity_factor
822+
avg_num_tokens = (
823+
int(max_num_tokens // cap_factor)
824+
if cap_factor is not None and cap_factor > 0
825+
else None
826+
)
812827
offload_context = get_paged_stash_context(
813828
name="moe_act",
814-
max_num_tokens=fc1_output.shape[0],
829+
max_num_tokens=max_num_tokens,
815830
num_tokens_tensor=tokens_per_expert.sum(),
831+
avg_num_tokens=avg_num_tokens,
816832
)
817833
else:
818834
offload_context = nullcontext()
@@ -824,10 +840,16 @@ def glu(x):
824840
)
825841

826842
if self.moe_paged_stash_expert_fc2:
843+
max_num_tokens = bias_act_output.shape[0]
844+
cap_factor = self.config.moe_expert_rank_capacity_factor
845+
avg_num_tokens = (
846+
int(max_num_tokens // cap_factor) if cap_factor is not None and cap_factor > 0 else None
847+
)
827848
offload_context = get_paged_stash_context(
828849
name="expert_fc2",
829-
max_num_tokens=bias_act_output.shape[0],
850+
max_num_tokens=max_num_tokens,
830851
num_tokens_tensor=tokens_per_expert.sum(),
852+
avg_num_tokens=avg_num_tokens,
831853
)
832854
else:
833855
offload_context = nullcontext()

megatron/core/transformer/moe/paged_stash.py

Lines changed: 75 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -145,8 +145,11 @@ def _paged_stash_copy_kernel(
145145
need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0
146146
num_iters = elements_per_thread + (1 if need_mask else 0)
147147

148-
src_base = src_ptr + token_idx * HIDDEN_SIZE
149-
dst_base = dst_ptr + dst_token_idx * HIDDEN_SIZE
148+
# Use int64 for address math to avoid int32 overflow when indices get large.
149+
token_idx_i64 = token_idx.to(tl.int64)
150+
dst_token_idx_i64 = dst_token_idx.to(tl.int64)
151+
src_base = src_ptr + token_idx_i64 * HIDDEN_SIZE
152+
dst_base = dst_ptr + dst_token_idx_i64 * HIDDEN_SIZE
150153

151154
if need_mask:
152155
for iter in range(num_iters):
@@ -219,8 +222,11 @@ def _paged_stash_pop_kernel(
219222
need_mask = (HIDDEN_SIZE % BLOCK_SIZE) != 0
220223
num_iters = elements_per_thread + (1 if need_mask else 0)
221224

222-
src_base = src_ptr + src_token_idx * HIDDEN_SIZE
223-
dst_base = dst_ptr + token_idx * HIDDEN_SIZE
225+
# Use int64 for address math to avoid int32 overflow when indices get large.
226+
src_token_idx_i64 = src_token_idx.to(tl.int64)
227+
token_idx_i64 = token_idx.to(tl.int64)
228+
src_base = src_ptr + src_token_idx_i64 * HIDDEN_SIZE
229+
dst_base = dst_ptr + token_idx_i64 * HIDDEN_SIZE
224230

225231
if need_mask:
226232
for iter in range(num_iters):
@@ -261,6 +267,7 @@ def __init__(
261267
self,
262268
tensor,
263269
num_tokens_tensor=None,
270+
avg_num_tokens: int = None,
264271
vp_stage=None,
265272
schedule_layer_no=None,
266273
layer_name=None,
@@ -284,6 +291,7 @@ def __init__(
284291
and num_tokens_tensor.numel() == 1
285292
)
286293
self.num_tokens_tensor = num_tokens_tensor.clone()
294+
self.avg_num_tokens = avg_num_tokens
287295
self.vp_stage = vp_stage
288296
self.schedule_layer_no = schedule_layer_no
289297
self.layer_name = layer_name
@@ -517,7 +525,7 @@ def __init__(self):
517525
"""Initialize the manager with queues and dedicated CUDA streams."""
518526
# allocate streams and events for synchronization
519527
self.enabled = False
520-
self._pack_stream = torch.cuda.Stream()
528+
self._pack_stream = torch.cuda.current_stream()#torch.cuda.Stream()
521529
# Currently paged stashing is not stream-safe, so use the same stream for packing
522530
# and unpacking
523531
self._unpack_stream = self._pack_stream
@@ -543,9 +551,14 @@ def __init__(self):
543551
# Track max tokens needed across all vp_stages grouped by dtype and hidden_size
544552
self.max_tokens_across_vp_stages = None
545553
self.temp_tokens_across_vp_stages = None
554+
# Track max tokens computed from avg_num_tokens (heuristic) across all vp_stages
555+
self.max_avg_tokens_across_vp_stages = None
556+
self.temp_avg_tokens_across_vp_stages = None
546557

547558
self.num_tokens_tensor = None
548559
self.max_num_tokens = None
560+
# Optional hint: expected/average number of tokens (e.g., pre-padding estimate)
561+
self.avg_num_tokens = None
549562
self.stash_buffers = None
550563
self.overflow = None
551564
self.device = None
@@ -663,12 +676,28 @@ def allocate_stash_buffers(self, stash_buffer_size_factor=1.10):
663676
self.stash_buffers = {}
664677
self.overflow = torch.zeros(1, dtype=torch.int64, device=self.device)
665678

666-
for dtype, hidden_size in self.max_tokens_across_vp_stages:
679+
# stash_buffer_size_factor controls both which sizing signal to use and how much headroom
680+
# to allocate:
681+
# - positive: size based on avg_num_tokens-derived maxima
682+
# - negative: size based on actual num_tokens-derived maxima (legacy behavior)
683+
# In both cases we scale by abs(stash_buffer_size_factor).
684+
if stash_buffer_size_factor >= 0:
685+
max_tokens_dict = self.max_avg_tokens_across_vp_stages
686+
scale = stash_buffer_size_factor
687+
else:
688+
max_tokens_dict = self.max_tokens_across_vp_stages
689+
scale = -stash_buffer_size_factor
690+
691+
# Fallback safety: if avg-based dict is not available/populated yet, use actual-max dict.
692+
if not max_tokens_dict:
693+
max_tokens_dict = self.max_tokens_across_vp_stages
694+
695+
for dtype, hidden_size in max_tokens_dict:
667696
if dtype not in self.stash_buffers:
668697
self.stash_buffers[dtype] = {}
669698
assert hidden_size not in self.stash_buffers[dtype]
670699
num_tokens = int(
671-
self.max_tokens_across_vp_stages[dtype, hidden_size] * stash_buffer_size_factor
700+
max_tokens_dict[dtype, hidden_size] * scale
672701
)
673702
self.stash_buffers[dtype][hidden_size] = PagedStashBuffer(
674703
num_tokens, hidden_size, self.page_size, self.device, self.overflow, dtype
@@ -721,9 +750,13 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
721750
tensor._rowwise_data is None
722751
), f"rowwise_data is not None; Only columnwise data is supported for paged stashing"
723752

753+
avg_num_tokens = None
724754
if self.status == 'capture':
725755

726756
self.num_tokens = self.num_tokens_tensor.item()
757+
avg_num_tokens = (
758+
int(self.avg_num_tokens) if self.avg_num_tokens is not None else None
759+
)
727760

728761
dtype = (
729762
tensor.dtype
@@ -743,12 +776,22 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
743776
if (dtype, hidden_size) not in self.temp_tokens_across_vp_stages:
744777
self.temp_tokens_across_vp_stages[dtype, hidden_size] = 0
745778
self.max_tokens_across_vp_stages[dtype, hidden_size] = 0
779+
self.temp_avg_tokens_across_vp_stages[dtype, hidden_size] = 0
780+
self.max_avg_tokens_across_vp_stages[dtype, hidden_size] = 0
746781

747782
self.temp_tokens_across_vp_stages[dtype, hidden_size] += self.num_tokens
748783
self.max_tokens_across_vp_stages[dtype, hidden_size] = max(
749784
self.max_tokens_across_vp_stages[dtype, hidden_size],
750785
self.temp_tokens_across_vp_stages[dtype, hidden_size],
751786
)
787+
788+
# Track avg tokens across vp stages (if provided) using the same accumulation model.
789+
if avg_num_tokens is not None:
790+
self.temp_avg_tokens_across_vp_stages[dtype, hidden_size] += avg_num_tokens
791+
self.max_avg_tokens_across_vp_stages[dtype, hidden_size] = max(
792+
self.max_avg_tokens_across_vp_stages[dtype, hidden_size],
793+
self.temp_avg_tokens_across_vp_stages[dtype, hidden_size],
794+
)
752795
# Since capture stage does not use CUDA graph, we can truncate
753796
# the saved tensor to actual num_tokens
754797
new_size = (self.num_tokens, *tensor.shape[1:])
@@ -767,6 +810,7 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
767810
paged_tensor = PagedTensor(
768811
tensor,
769812
num_tokens_tensor=self.num_tokens_tensor,
813+
avg_num_tokens=avg_num_tokens,
770814
vp_stage=self.current_vp_stage,
771815
schedule_layer_no=(
772816
self._pp_schedule[self.current_schedule_index]
@@ -791,6 +835,14 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
791835
if isinstance(saved_state, (PagedTensor)):
792836
if self.status == 'capture':
793837
num_tokens = saved_state.num_tokens_tensor.item()
838+
key = (saved_state.dtype, saved_state.hidden_size)
839+
if key in self.temp_tokens_across_vp_stages:
840+
self.temp_tokens_across_vp_stages[key] -= num_tokens
841+
if (
842+
saved_state.avg_num_tokens is not None
843+
and key in self.temp_avg_tokens_across_vp_stages
844+
):
845+
self.temp_avg_tokens_across_vp_stages[key] -= int(saved_state.avg_num_tokens)
794846
# Pad the tensor to the max number of tokens
795847
npad = self.max_num_tokens - num_tokens
796848
pad = ()
@@ -811,6 +863,13 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
811863
assert (
812864
saved_state._tensor is not None
813865
), f"saved_state._tensor is None {saved_state._tensor}"
866+
867+
# Record cross-stream usage (important when tensor was produced on another stream).
868+
if isinstance(saved_state._tensor, MXFP8Tensor):
869+
saved_state._tensor._columnwise_data.record_stream(torch.cuda.current_stream())
870+
elif isinstance(saved_state._tensor, torch.Tensor) and saved_state._tensor.is_cuda:
871+
saved_state._tensor.record_stream(torch.cuda.current_stream())
872+
814873
return saved_state._tensor
815874

816875
return saved_state
@@ -855,12 +914,18 @@ def paged_stash_group_start(tensor):
855914
return PP_PreScheduleFunction.apply(tensor, stash_manager)
856915

857916

858-
def get_paged_stash_context(name=None, max_num_tokens=None, num_tokens_tensor=None):
917+
def get_paged_stash_context(
918+
name=None,
919+
max_num_tokens=None,
920+
num_tokens_tensor=None,
921+
avg_num_tokens=None,
922+
):
859923
"""Get the paged stash context"""
860924
stash_manager = PagedStashManager.get_instance()
861925
if not stash_manager.enabled:
862926
return nullcontext()
863927
stash_manager.max_num_tokens = max_num_tokens
928+
stash_manager.avg_num_tokens = avg_num_tokens
864929
assert num_tokens_tensor is not None and isinstance(num_tokens_tensor, torch.Tensor)
865930
stash_manager.num_tokens_tensor = num_tokens_tensor
866931
stash_manager.set_current_layer_name(name) if name is not None else None
@@ -891,6 +956,8 @@ def paged_stash_init_chunk_handler(vp_size, vp_stage):
891956
if stash_manager.max_tokens_across_vp_stages is None:
892957
stash_manager.max_tokens_across_vp_stages = {}
893958
stash_manager.temp_tokens_across_vp_stages = {}
959+
stash_manager.max_avg_tokens_across_vp_stages = {}
960+
stash_manager.temp_avg_tokens_across_vp_stages = {}
894961

895962

896963
def paged_stash_set_last_layer(is_last_layer=False):

0 commit comments

Comments
 (0)