@@ -247,7 +247,7 @@ class PagedTensor:
     A paged tensor that stores data in pages within a paged stash buffer.
     """
 
-    def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer_no=None, layer_name=None, max_tokens=None, page_size=64, num_d2d_pages=0):
+    def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer_no=None, layer_name=None, max_tokens=None, page_size=64):
         """
         Args:
             tensor: The tensor to store
@@ -256,7 +256,6 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer
             layer_name: Name of the layer
            max_tokens: Maximum number of tokens
             page_size: Number of tokens per page
-            num_d2d_pages: Number of pages to copy using native PyTorch (rest uses Triton)
         """
         self._tensor = tensor
         self._original_tensor = None
@@ -267,7 +266,6 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer
         self.layer_name = layer_name
         self.max_tokens = max_tokens
         self.page_size = page_size
-        self.num_d2d_pages = num_d2d_pages
 
         # Original tensor information
         self.original_shape = list(tensor.shape)
@@ -282,13 +280,6 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer
 
         # Page record: stores which pages are being used for this tensor
         self.page_record = torch.zeros(self.max_num_pages, dtype=torch.int64, device=self.device)
-
-        # Static tensor for D2D pages (allocate upfront if needed)
-        d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens)
-        if d2d_tokens > 0:
-            self.static_tensor = torch.empty((d2d_tokens, self.hidden_size), dtype=self.dtype, device=self.device)
-        else:
-            self.static_tensor = None
 
     @property
     def schedule_layer(self):
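For reviewers skimming the hunks above: the deleted branch preallocated `static_tensor` and served the first `num_d2d_pages` pages with a plain device-to-device copy, leaving only the tail to the Triton kernel. A minimal sketch of that removed path, with the class plumbing elided (names mirror the deleted code; the function wrapper here is hypothetical):

```python
import torch

# Sketch of the removed hybrid offload path, for reference only.
def offload_hybrid(tensor, static_tensor, num_d2d_pages, page_size, max_num_tokens):
    # Head of the tensor: plain D2D copy into the preallocated static buffer.
    d2d_tokens = min(num_d2d_pages * page_size, max_num_tokens)
    if d2d_tokens > 0:
        static_tensor[:d2d_tokens] = tensor[:d2d_tokens]
    # Tail of the tensor: this slice went to the Triton copy kernel.
    return tensor[d2d_tokens:max_num_tokens]
```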
@@ -312,48 +303,33 @@ def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048
         else:
             tensor_to_copy = self._tensor
 
-        # Split tensor into two parts: D2D portion and Triton portion
-        # Use max_num_tokens for consistent size across iterations
-        d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens)
-        triton_tokens = self.max_num_tokens - d2d_tokens
-
-        # Perform both D2D copy and Triton kernel together
-        # Part 1: Copy first d2d_tokens to static_tensor using native PyTorch
-        if d2d_tokens > 0:
-            self.static_tensor[:d2d_tokens] = tensor_to_copy[:d2d_tokens]
-        # Part 2: Copy remaining tokens using Triton kernel
-        if triton_tokens > 0:
-            triton_tensor = tensor_to_copy[d2d_tokens:self.max_num_tokens]
-            # Use actual num_tokens for the kernel (how many tokens to actually copy)
-            triton_num_tokens = self.num_tokens_tensor - d2d_tokens
-
-            # Determine grid size
-            BLOCK_SIZE = GLOBAL_BLOCK_SIZE
-            num_blocks = min(triton_tokens, max_blocks)
-            grid = (num_blocks,)
-
-            # Create temporary tensor for new head
-            new_free_list_head = torch.empty(1, dtype=torch.int64, device=self.device)
-
-            # Launch paged stash copy kernel
-            _paged_stash_copy_kernel[grid](
-                triton_tensor,
-                paged_stash_buffer.buffer,
-                triton_num_tokens,
-                paged_stash_buffer.free_list,
-                paged_stash_buffer.free_list_head,
-                paged_stash_buffer.free_list_tail,
-                paged_stash_buffer.free_list_capacity,
-                self.page_record,  # Triton kernel will populate page_record
-                paged_stash_buffer.overflow,
-                new_free_list_head,
-                PAGE_SIZE=self.page_size,
-                HIDDEN_SIZE=self.hidden_size,
-                BLOCK_SIZE=BLOCK_SIZE,
-            )
-
-            # Update free list head
-            paged_stash_buffer.free_list_head.copy_(new_free_list_head)
+        # Determine grid size
+        BLOCK_SIZE = GLOBAL_BLOCK_SIZE
+        num_blocks = min(self.max_num_tokens, max_blocks)
+        grid = (num_blocks,)
+
+        # Create temporary tensor for new head
+        new_free_list_head = torch.empty(1, dtype=torch.int64, device=self.device)
+
+        # Launch paged stash copy kernel
+        _paged_stash_copy_kernel[grid](
+            tensor_to_copy,
+            paged_stash_buffer.buffer,
+            self.num_tokens_tensor,
+            paged_stash_buffer.free_list,
+            paged_stash_buffer.free_list_head,
+            paged_stash_buffer.free_list_tail,
+            paged_stash_buffer.free_list_capacity,
+            self.page_record,  # Triton kernel will populate page_record
+            paged_stash_buffer.overflow,
+            new_free_list_head,
+            PAGE_SIZE=self.page_size,
+            HIDDEN_SIZE=self.hidden_size,
+            BLOCK_SIZE=BLOCK_SIZE,
+        )
+
+        # Update free list head
+        paged_stash_buffer.free_list_head.copy_(new_free_list_head)
 
         # Save reference to original tensor
         self._original_tensor = self._tensor
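The simplified path routes every token through `_paged_stash_copy_kernel`. For intuition, here is a plain-PyTorch reference of what that kernel appears to do, under assumed free-list semantics (a ring of free page ids consumed at the head, with head==tail meaning exhausted); this is a sequential sketch for reasoning, not a drop-in replacement for the parallel kernel:

```python
import torch

def stash_copy_reference(src, buffer, num_tokens, free_list, free_list_head,
                         free_list_tail, free_list_capacity, page_record,
                         overflow, page_size):
    head = int(free_list_head)
    n = int(num_tokens)
    num_pages = (n + page_size - 1) // page_size
    for p in range(num_pages):
        if head == int(free_list_tail):  # ring exhausted (assumed convention)
            overflow.fill_(1)
            break
        # Pop a free page id from the head of the ring.
        page = int(free_list[head % free_list_capacity])
        head += 1
        page_record[p] = page  # remember where this slice of tokens lives
        lo, hi = p * page_size, min((p + 1) * page_size, n)
        buffer[page * page_size : page * page_size + (hi - lo)] = src[lo:hi]
    free_list_head.fill_(head)  # mirrors the new_free_list_head copy above
```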
@@ -384,48 +360,32 @@ def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=204
             self._tensor = torch.empty(self.original_shape, dtype=self.dtype, device=self.device)
         tensor_to_reload = self._tensor
 
-        # Split tensor into two parts: D2D portion and Triton portion
-        # Use max_num_tokens for consistency with stash
-        d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens)
-        triton_tokens = self.max_num_tokens - d2d_tokens
-
-        # Perform both D2D copy and Triton kernel together
-        # Part 1: Copy first d2d_tokens from static_tensor using native PyTorch
-        if d2d_tokens > 0 and self.static_tensor is not None:
-            tensor_to_reload[:d2d_tokens] = self.static_tensor[:d2d_tokens]
-
-        # Part 2: Copy remaining tokens using Triton kernel
-        if triton_tokens > 0:
-            triton_tensor = tensor_to_reload[d2d_tokens:self.max_num_tokens]
-            # Use actual num_tokens for the kernel (how many tokens to actually copy)
-            triton_num_tokens = self.num_tokens_tensor - d2d_tokens
-
-            # Determine grid size
-            BLOCK_SIZE = GLOBAL_BLOCK_SIZE
-            num_blocks = min(triton_tokens, max_blocks)
-            grid = (num_blocks,)
-
-            # Create temporary tensor for new tail
-            new_free_list_tail = torch.empty(1, dtype=torch.int64, device=self.device)
-
-            # Launch paged stash pop kernel
-            _paged_stash_pop_kernel[grid](
-                paged_stash_buffer.buffer,
-                triton_tensor,
-                triton_num_tokens,
-                self.page_record,  # Triton kernel will read from page_record
-                paged_stash_buffer.free_list,
-                paged_stash_buffer.free_list_head,
-                paged_stash_buffer.free_list_tail,
-                paged_stash_buffer.free_list_capacity,
-                new_free_list_tail,
-                PAGE_SIZE=self.page_size,
-                HIDDEN_SIZE=self.hidden_size,
-                BLOCK_SIZE=BLOCK_SIZE,
-            )
-
-            # Update free list tail
-            paged_stash_buffer.free_list_tail.copy_(new_free_list_tail)
+        # Determine grid size
+        BLOCK_SIZE = GLOBAL_BLOCK_SIZE
+        num_blocks = min(self.max_num_tokens, max_blocks)
+        grid = (num_blocks,)
+
+        # Create temporary tensor for new tail
+        new_free_list_tail = torch.empty(1, dtype=torch.int64, device=self.device)
+
+        # Launch paged stash pop kernel
+        _paged_stash_pop_kernel[grid](
+            paged_stash_buffer.buffer,
+            tensor_to_reload,
+            self.num_tokens_tensor,
+            self.page_record,  # Triton kernel will read from page_record
+            paged_stash_buffer.free_list,
+            paged_stash_buffer.free_list_head,
+            paged_stash_buffer.free_list_tail,
+            paged_stash_buffer.free_list_capacity,
+            new_free_list_tail,
+            PAGE_SIZE=self.page_size,
+            HIDDEN_SIZE=self.hidden_size,
+            BLOCK_SIZE=BLOCK_SIZE,
+        )
+
+        # Update free list tail
+        paged_stash_buffer.free_list_tail.copy_(new_free_list_tail)
 
 
 class PP_PreScheduleFunction(torch.autograd.Function):
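`reload_from_stash` is the mirror image: `_paged_stash_pop_kernel` gathers the pages named in `page_record` back into the activation and recycles them at the free list's tail. A matching sequential sketch, under the same assumed ring semantics as the copy sketch above:

```python
import torch

def stash_pop_reference(buffer, dst, num_tokens, page_record, free_list,
                        free_list_tail, free_list_capacity, page_size):
    tail = int(free_list_tail)
    n = int(num_tokens)
    num_pages = (n + page_size - 1) // page_size
    for p in range(num_pages):
        page = int(page_record[p])
        lo, hi = p * page_size, min((p + 1) * page_size, n)
        dst[lo:hi] = buffer[page * page_size : page * page_size + (hi - lo)]
        free_list[tail % free_list_capacity] = page  # return page to the ring
        tail += 1
    free_list_tail.fill_(tail)  # mirrors the new_free_list_tail copy above
```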
@@ -555,9 +515,6 @@ def __init__(self):
 
         # Page size for paged memory management
         self.page_size = int(os.getenv('PAGED_STASH_PAGE_SIZE', '64'))  # Default 64 tokens per page
-
-        # Number of pages to copy using native PyTorch (D2D)
-        self.num_d2d_pages = int(os.getenv('NUM_D2D_PAGES', '0'))  # Default 0 (all Triton)
 
     @property
     def pack_stream(self):
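With `NUM_D2D_PAGES` removed, the page size is the only environment knob left in this constructor. A small sketch of how it is consumed (the `getenv` call matches the diff; the assertion is an added sanity check, not part of the source):

```python
import os

# Set before the process constructs this object, e.g.:
#   PAGED_STASH_PAGE_SIZE=128 python train.py
# Any lingering NUM_D2D_PAGES setting is now silently ignored.
page_size = int(os.getenv('PAGED_STASH_PAGE_SIZE', '64'))
assert page_size > 0, "PAGED_STASH_PAGE_SIZE must be a positive token count"
```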
@@ -765,7 +722,6 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
                 layer_name=self._current_layer_name,
                 max_tokens=self.max_num_tokens,
                 page_size=self.page_size,
-                num_d2d_pages=self.num_d2d_pages
             )
 
         if self.status == 'captured':
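A hypothetical end-to-end use of the slimmed-down constructor, matching the call site above. Only the `PagedTensor` signature comes from this diff; the `PagedStashBuffer` setup and the CUDA shapes are assumptions for illustration:

```python
import torch

stash = PagedStashBuffer(...)  # setup not shown in this diff
max_tokens, hidden = 8192, 4096
x = torch.randn(max_tokens, hidden, device='cuda', dtype=torch.bfloat16)
num_tokens = torch.tensor(max_tokens, dtype=torch.int64, device='cuda')

pt = PagedTensor(x, num_tokens_tensor=num_tokens,
                 max_tokens=max_tokens, page_size=64)
pt.offload_to_stash(stash)   # stash the activation into paged storage
pt.reload_from_stash(stash)  # restore it before backward consumes it
```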