@@ -145,8 +145,11 @@ def _paged_stash_copy_kernel(
145145 need_mask = (HIDDEN_SIZE % BLOCK_SIZE ) != 0
146146 num_iters = elements_per_thread + (1 if need_mask else 0 )
147147
148- src_base = src_ptr + token_idx * HIDDEN_SIZE
149- dst_base = dst_ptr + dst_token_idx * HIDDEN_SIZE
148+ # Use int64 for address math to avoid int32 overflow when indices get large.
149+ token_idx_i64 = token_idx .to (tl .int64 )
150+ dst_token_idx_i64 = dst_token_idx .to (tl .int64 )
151+ src_base = src_ptr + token_idx_i64 * HIDDEN_SIZE
152+ dst_base = dst_ptr + dst_token_idx_i64 * HIDDEN_SIZE
150153
151154 if need_mask :
152155 for iter in range (num_iters ):
@@ -219,8 +222,11 @@ def _paged_stash_pop_kernel(
219222 need_mask = (HIDDEN_SIZE % BLOCK_SIZE ) != 0
220223 num_iters = elements_per_thread + (1 if need_mask else 0 )
221224
222- src_base = src_ptr + src_token_idx * HIDDEN_SIZE
223- dst_base = dst_ptr + token_idx * HIDDEN_SIZE
225+ # Use int64 for address math to avoid int32 overflow when indices get large.
226+ src_token_idx_i64 = src_token_idx .to (tl .int64 )
227+ token_idx_i64 = token_idx .to (tl .int64 )
228+ src_base = src_ptr + src_token_idx_i64 * HIDDEN_SIZE
229+ dst_base = dst_ptr + token_idx_i64 * HIDDEN_SIZE
224230
225231 if need_mask :
226232 for iter in range (num_iters ):
@@ -261,6 +267,7 @@ def __init__(
261267 self ,
262268 tensor ,
263269 num_tokens_tensor = None ,
270+ avg_num_tokens : int = None ,
264271 vp_stage = None ,
265272 schedule_layer_no = None ,
266273 layer_name = None ,
@@ -284,6 +291,7 @@ def __init__(
284291 and num_tokens_tensor .numel () == 1
285292 )
286293 self .num_tokens_tensor = num_tokens_tensor .clone ()
294+ self .avg_num_tokens = avg_num_tokens
287295 self .vp_stage = vp_stage
288296 self .schedule_layer_no = schedule_layer_no
289297 self .layer_name = layer_name
@@ -517,7 +525,7 @@ def __init__(self):
517525 """Initialize the manager with queues and dedicated CUDA streams."""
518526 # allocate streams and events for synchronization
519527 self .enabled = False
520- self ._pack_stream = torch .cuda .Stream ()
520- self ._pack_stream = torch .cuda .Stream ()
528+ self ._pack_stream = torch .cuda .current_stream ()  # was a dedicated torch.cuda.Stream(); reuse the current stream until paged stashing is stream-safe
521529 # Currently paged stashing is not stream-safe, so use the same stream for packing
522530 # and unpacking
523531 self ._unpack_stream = self ._pack_stream
@@ -543,9 +551,14 @@ def __init__(self):
543551 # Track max tokens needed across all vp_stages grouped by dtype and hidden_size
544552 self .max_tokens_across_vp_stages = None
545553 self .temp_tokens_across_vp_stages = None
554+ # Track max tokens computed from avg_num_tokens (heuristic) across all vp_stages
555+ self .max_avg_tokens_across_vp_stages = None
556+ self .temp_avg_tokens_across_vp_stages = None
546557
547558 self .num_tokens_tensor = None
548559 self .max_num_tokens = None
560+ # Optional hint: expected/average number of tokens (e.g., pre-padding estimate)
561+ self .avg_num_tokens = None
549562 self .stash_buffers = None
550563 self .overflow = None
551564 self .device = None
@@ -663,12 +676,28 @@ def allocate_stash_buffers(self, stash_buffer_size_factor=1.10):
663676 self .stash_buffers = {}
664677 self .overflow = torch .zeros (1 , dtype = torch .int64 , device = self .device )
665678
666- for dtype , hidden_size in self .max_tokens_across_vp_stages :
679+ # stash_buffer_size_factor controls both which sizing signal to use and how much headroom
680+ # to allocate:
681+ # - non-negative (>= 0): size based on avg_num_tokens-derived maxima
682+ # - negative: size based on actual num_tokens-derived maxima (legacy behavior)
683+ # In both cases we scale by abs(stash_buffer_size_factor).
684+ if stash_buffer_size_factor >= 0 :
685+ max_tokens_dict = self .max_avg_tokens_across_vp_stages
686+ scale = stash_buffer_size_factor
687+ else :
688+ max_tokens_dict = self .max_tokens_across_vp_stages
689+ scale = - stash_buffer_size_factor
690+
691+ # Fallback safety: if avg-based dict is not available/populated yet, use actual-max dict.
692+ if not max_tokens_dict :
693+ max_tokens_dict = self .max_tokens_across_vp_stages
694+
695+ for dtype , hidden_size in max_tokens_dict :
667696 if dtype not in self .stash_buffers :
668697 self .stash_buffers [dtype ] = {}
669698 assert hidden_size not in self .stash_buffers [dtype ]
670699 num_tokens = int (
671- self . max_tokens_across_vp_stages [dtype , hidden_size ] * stash_buffer_size_factor
700+ max_tokens_dict [dtype , hidden_size ] * scale
672701 )
673702 self .stash_buffers [dtype ][hidden_size ] = PagedStashBuffer (
674703 num_tokens , hidden_size , self .page_size , self .device , self .overflow , dtype
@@ -721,9 +750,13 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
721750 tensor ._rowwise_data is None
722751 ), f"rowwise_data is not None; Only columnwise data is supported for paged stashing"
723752
753+ avg_num_tokens = None
724754 if self .status == 'capture' :
725755
726756 self .num_tokens = self .num_tokens_tensor .item ()
757+ avg_num_tokens = (
758+ int (self .avg_num_tokens ) if self .avg_num_tokens is not None else None
759+ )
727760
728761 dtype = (
729762 tensor .dtype
@@ -743,12 +776,22 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
743776 if (dtype , hidden_size ) not in self .temp_tokens_across_vp_stages :
744777 self .temp_tokens_across_vp_stages [dtype , hidden_size ] = 0
745778 self .max_tokens_across_vp_stages [dtype , hidden_size ] = 0
779+ self .temp_avg_tokens_across_vp_stages [dtype , hidden_size ] = 0
780+ self .max_avg_tokens_across_vp_stages [dtype , hidden_size ] = 0
746781
747782 self .temp_tokens_across_vp_stages [dtype , hidden_size ] += self .num_tokens
748783 self .max_tokens_across_vp_stages [dtype , hidden_size ] = max (
749784 self .max_tokens_across_vp_stages [dtype , hidden_size ],
750785 self .temp_tokens_across_vp_stages [dtype , hidden_size ],
751786 )
787+
788+ # Track avg tokens across vp stages (if provided) using the same accumulation model.
789+ if avg_num_tokens is not None :
790+ self .temp_avg_tokens_across_vp_stages [dtype , hidden_size ] += avg_num_tokens
791+ self .max_avg_tokens_across_vp_stages [dtype , hidden_size ] = max (
792+ self .max_avg_tokens_across_vp_stages [dtype , hidden_size ],
793+ self .temp_avg_tokens_across_vp_stages [dtype , hidden_size ],
794+ )
752795 # Since capture stage does not use CUDA graph, we can truncate
753796 # the saved tensor to actual num_tokens
754797 new_size = (self .num_tokens , * tensor .shape [1 :])
@@ -767,6 +810,7 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
767810 paged_tensor = PagedTensor (
768811 tensor ,
769812 num_tokens_tensor = self .num_tokens_tensor ,
813+ avg_num_tokens = avg_num_tokens ,
770814 vp_stage = self .current_vp_stage ,
771815 schedule_layer_no = (
772816 self ._pp_schedule [self .current_schedule_index ]
@@ -791,6 +835,14 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
791835 if isinstance (saved_state , (PagedTensor )):
792836 if self .status == 'capture' :
793837 num_tokens = saved_state .num_tokens_tensor .item ()
838+ key = (saved_state .dtype , saved_state .hidden_size )
839+ if key in self .temp_tokens_across_vp_stages :
840+ self .temp_tokens_across_vp_stages [key ] -= num_tokens
841+ if (
842+ saved_state .avg_num_tokens is not None
843+ and key in self .temp_avg_tokens_across_vp_stages
844+ ):
845+ self .temp_avg_tokens_across_vp_stages [key ] -= int (saved_state .avg_num_tokens )
794846 # Pad the tensor to the max number of tokens
795847 npad = self .max_num_tokens - num_tokens
796848 pad = ()
@@ -811,6 +863,13 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
811863 assert (
812864 saved_state ._tensor is not None
813865 ), f"saved_state._tensor is None { saved_state ._tensor } "
866+
867+ # Record cross-stream usage (important when tensor was produced on another stream).
868+ if isinstance (saved_state ._tensor , MXFP8Tensor ):
869+ saved_state ._tensor ._columnwise_data .record_stream (torch .cuda .current_stream ())
870+ elif isinstance (saved_state ._tensor , torch .Tensor ) and saved_state ._tensor .is_cuda :
871+ saved_state ._tensor .record_stream (torch .cuda .current_stream ())
872+
814873 return saved_state ._tensor
815874
816875 return saved_state
@@ -855,12 +914,18 @@ def paged_stash_group_start(tensor):
855914 return PP_PreScheduleFunction .apply (tensor , stash_manager )
856915
857916
858- def get_paged_stash_context (name = None , max_num_tokens = None , num_tokens_tensor = None ):
917+ def get_paged_stash_context (
918+ name = None ,
919+ max_num_tokens = None ,
920+ num_tokens_tensor = None ,
921+ avg_num_tokens = None ,
922+ ):
859923 """Get the paged stash context"""
860924 stash_manager = PagedStashManager .get_instance ()
861925 if not stash_manager .enabled :
862926 return nullcontext ()
863927 stash_manager .max_num_tokens = max_num_tokens
928+ stash_manager .avg_num_tokens = avg_num_tokens
864929 assert num_tokens_tensor is not None and isinstance (num_tokens_tensor , torch .Tensor )
865930 stash_manager .num_tokens_tensor = num_tokens_tensor
866931 stash_manager .set_current_layer_name (name ) if name is not None else None
@@ -891,6 +956,8 @@ def paged_stash_init_chunk_handler(vp_size, vp_stage):
891956 if stash_manager .max_tokens_across_vp_stages is None :
892957 stash_manager .max_tokens_across_vp_stages = {}
893958 stash_manager .temp_tokens_across_vp_stages = {}
959+ stash_manager .max_avg_tokens_across_vp_stages = {}
960+ stash_manager .temp_avg_tokens_across_vp_stages = {}
894961
895962
896963def paged_stash_set_last_layer (is_last_layer = False ):