Skip to content

Commit f733d51

Browse files
committed
Remove d2d page feature for now
Remove unused Triton kernel for dropping tokens in case an overflow happens
1 parent 088ea6a commit f733d51

File tree

3 files changed

+54
-196
lines changed

3 files changed

+54
-196
lines changed

megatron/core/transformer/moe/moe_utils.py

Lines changed: 0 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,6 @@
3636
HAVE_TE = False
3737

3838

39-
import triton
40-
import triton.language as tl
41-
42-
43-
4439
# MOE logging
4540
_MOE_LAYER_WISE_LOGGING_TRACKER = {}
4641

@@ -1293,95 +1288,3 @@ def wrapped_func(moe_layer, *args, **kwargs):
12931288
return wrapped_func
12941289

12951290
return decorator
1296-
1297-
@triton.jit
1298-
def _drop_routing_map_kernel(
1299-
routing_map_ptr,
1300-
over_budget_ptr,
1301-
routing_map_dropped_ptr,
1302-
num_elements: tl.constexpr,
1303-
BLOCK_SIZE: tl.constexpr,
1304-
):
1305-
"""Triton kernel to drop routing map based on budget constraints.
1306-
1307-
Args:
1308-
routing_map_ptr: Pointer to the input routing_map tensor
1309-
over_budget_ptr: Pointer to the boolean tensor indicating if any EP rank is over budget
1310-
routing_map_dropped_ptr: Pointer to the output routing_map tensor
1311-
num_elements: Total number of elements to process
1312-
BLOCK_SIZE: Block size for Triton kernel
1313-
"""
1314-
# Get the program ID
1315-
pid = tl.program_id(axis=0)
1316-
1317-
# Read the over_budget value (scalar tensor with single element)
1318-
over_budget_val = tl.load(over_budget_ptr)
1319-
1320-
# Calculate the offset for this program
1321-
offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
1322-
1323-
# Load the routing_map values
1324-
mask = offset < num_elements
1325-
routing_map_val = tl.load(routing_map_ptr + offset, mask=mask, other=0.0)
1326-
1327-
# If over_budget is 1 (True), output is 0 (drop); if over_budget is 0 (False), output is routing_map_val (keep)
1328-
output_val = routing_map_val * (1 - over_budget_val)
1329-
1330-
# Store the result
1331-
tl.store(routing_map_dropped_ptr + offset, output_val, mask=mask)
1332-
1333-
1334-
def drop_routing_map_triton(
1335-
routing_map: torch.Tensor,
1336-
budget: torch.Tensor,
1337-
num_tokens_per_ep_rank: torch.Tensor
1338-
) -> torch.Tensor:
1339-
"""Drop tokens from routing_map that exceed the budget per EP rank using Triton.
1340-
1341-
Args:
1342-
routing_map: Tensor indicating which tokens are assigned to each expert.
1343-
budget: Integer tensor with the maximum number of tokens per EP rank.
1344-
num_tokens_per_ep_rank: Tensor with actual number of tokens per EP rank.
1345-
1346-
Returns:
1347-
Modified routing_map with tokens exceeding budget zeroed out if any EP rank
1348-
exceeds budget, otherwise returns the original routing_map.
1349-
"""
1350-
1351-
# Calculate boolean tensor: over_budget is True if ANY EP rank exceeds budget
1352-
over_budget = (num_tokens_per_ep_rank > budget).any()
1353-
1354-
# Convert boolean to int8
1355-
over_budget_int = over_budget.to(torch.int8)
1356-
1357-
# Convert routing_map to numeric type if it's boolean
1358-
if routing_map.dtype == torch.bool:
1359-
routing_map_numeric = routing_map.to(torch.int8)
1360-
else:
1361-
routing_map_numeric = routing_map
1362-
1363-
# Create output tensor with same dtype as input
1364-
routing_map_dropped = torch.empty_like(routing_map_numeric)
1365-
1366-
# Flatten tensors for kernel processing
1367-
routing_map_flat = routing_map_numeric.flatten()
1368-
num_elements = routing_map_flat.numel()
1369-
1370-
# Determine grid size
1371-
BLOCK_SIZE = 1024
1372-
grid = (triton.cdiv(num_elements, BLOCK_SIZE),)
1373-
1374-
# Launch kernel with over_budget tensor pointer (as int8)
1375-
_drop_routing_map_kernel[grid](
1376-
routing_map_flat,
1377-
over_budget_int,
1378-
routing_map_dropped.flatten(),
1379-
num_elements,
1380-
BLOCK_SIZE=BLOCK_SIZE,
1381-
)
1382-
1383-
# Convert back to boolean if original was boolean
1384-
if routing_map.dtype == torch.bool:
1385-
routing_map_dropped = routing_map_dropped.to(torch.bool)
1386-
1387-
return routing_map_dropped, over_budget.to(torch.bool)

megatron/core/transformer/moe/paged_stash.py

Lines changed: 54 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ class PagedTensor:
247247
A paged tensor that stores data in pages within a paged stash buffer.
248248
"""
249249

250-
def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer_no=None, layer_name=None, max_tokens=None, page_size=64, num_d2d_pages=0):
250+
def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer_no=None, layer_name=None, max_tokens=None, page_size=64):
251251
"""
252252
Args:
253253
tensor: The tensor to store
@@ -256,7 +256,6 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer
256256
layer_name: Name of the layer
257257
max_tokens: Maximum number of tokens
258258
page_size: Number of tokens per page
259-
num_d2d_pages: Number of pages to copy using native PyTorch (rest uses Triton)
260259
"""
261260
self._tensor = tensor
262261
self._original_tensor = None
@@ -267,7 +266,6 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer
267266
self.layer_name = layer_name
268267
self.max_tokens = max_tokens
269268
self.page_size = page_size
270-
self.num_d2d_pages = num_d2d_pages
271269

272270
# Original tensor information
273271
self.original_shape = list(tensor.shape)
@@ -282,13 +280,6 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer
282280

283281
# Page record: stores which pages are being used for this tensor
284282
self.page_record = torch.zeros(self.max_num_pages, dtype=torch.int64, device=self.device)
285-
286-
# Static tensor for D2D pages (allocate upfront if needed)
287-
d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens)
288-
if d2d_tokens > 0:
289-
self.static_tensor = torch.empty((d2d_tokens, self.hidden_size), dtype=self.dtype, device=self.device)
290-
else:
291-
self.static_tensor = None
292283

293284
@property
294285
def schedule_layer(self):
@@ -312,48 +303,33 @@ def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048
312303
else:
313304
tensor_to_copy = self._tensor
314305

315-
# Split tensor into two parts: D2D portion and Triton portion
316-
# Use max_num_tokens for consistent size across iterations
317-
d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens)
318-
triton_tokens = self.max_num_tokens - d2d_tokens
319-
320-
# Perform both D2D copy and Triton kernel together
321-
# Part 1: Copy first d2d_tokens to static_tensor using native PyTorch
322-
if d2d_tokens > 0:
323-
self.static_tensor[:d2d_tokens] = tensor_to_copy[:d2d_tokens]
324-
# Part 2: Copy remaining tokens using Triton kernel
325-
if triton_tokens > 0:
326-
triton_tensor = tensor_to_copy[d2d_tokens:self.max_num_tokens]
327-
# Use actual num_tokens for the kernel (how many tokens to actually copy)
328-
triton_num_tokens = self.num_tokens_tensor - d2d_tokens
329-
330-
# Determine grid size
331-
BLOCK_SIZE = GLOBAL_BLOCK_SIZE
332-
num_blocks = min(triton_tokens, max_blocks)
333-
grid = (num_blocks,)
334-
335-
# Create temporary tensor for new head
336-
new_free_list_head = torch.empty(1, dtype=torch.int64, device=self.device)
337-
338-
# Launch paged stash copy kernel
339-
_paged_stash_copy_kernel[grid](
340-
triton_tensor,
341-
paged_stash_buffer.buffer,
342-
triton_num_tokens,
343-
paged_stash_buffer.free_list,
344-
paged_stash_buffer.free_list_head,
345-
paged_stash_buffer.free_list_tail,
346-
paged_stash_buffer.free_list_capacity,
347-
self.page_record, # Triton kernel will populate page_record
348-
paged_stash_buffer.overflow,
349-
new_free_list_head,
350-
PAGE_SIZE=self.page_size,
351-
HIDDEN_SIZE=self.hidden_size,
352-
BLOCK_SIZE=BLOCK_SIZE,
353-
)
354-
355-
# Update free list head
356-
paged_stash_buffer.free_list_head.copy_(new_free_list_head)
306+
# Determine grid size
307+
BLOCK_SIZE = GLOBAL_BLOCK_SIZE
308+
num_blocks = min(self.max_num_tokens, max_blocks)
309+
grid = (num_blocks,)
310+
311+
# Create temporary tensor for new head
312+
new_free_list_head = torch.empty(1, dtype=torch.int64, device=self.device)
313+
314+
# Launch paged stash copy kernel
315+
_paged_stash_copy_kernel[grid](
316+
tensor_to_copy,
317+
paged_stash_buffer.buffer,
318+
self.num_tokens_tensor,
319+
paged_stash_buffer.free_list,
320+
paged_stash_buffer.free_list_head,
321+
paged_stash_buffer.free_list_tail,
322+
paged_stash_buffer.free_list_capacity,
323+
self.page_record, # Triton kernel will populate page_record
324+
paged_stash_buffer.overflow,
325+
new_free_list_head,
326+
PAGE_SIZE=self.page_size,
327+
HIDDEN_SIZE=self.hidden_size,
328+
BLOCK_SIZE=BLOCK_SIZE,
329+
)
330+
331+
# Update free list head
332+
paged_stash_buffer.free_list_head.copy_(new_free_list_head)
357333

358334
# Save reference to original tensor
359335
self._original_tensor = self._tensor
@@ -384,48 +360,32 @@ def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=204
384360
self._tensor = torch.empty(self.original_shape, dtype=self.dtype, device=self.device)
385361
tensor_to_reload = self._tensor
386362

387-
# Split tensor into two parts: D2D portion and Triton portion
388-
# Use max_num_tokens for consistency with stash
389-
d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens)
390-
triton_tokens = self.max_num_tokens - d2d_tokens
391-
392-
# Perform both D2D copy and Triton kernel together
393-
# Part 1: Copy first d2d_tokens from static_tensor using native PyTorch
394-
if d2d_tokens > 0 and self.static_tensor is not None:
395-
tensor_to_reload[:d2d_tokens] = self.static_tensor[:d2d_tokens]
396-
397-
# Part 2: Copy remaining tokens using Triton kernel
398-
if triton_tokens > 0:
399-
triton_tensor = tensor_to_reload[d2d_tokens:self.max_num_tokens]
400-
# Use actual num_tokens for the kernel (how many tokens to actually copy)
401-
triton_num_tokens = self.num_tokens_tensor - d2d_tokens
402-
403-
# Determine grid size
404-
BLOCK_SIZE = GLOBAL_BLOCK_SIZE
405-
num_blocks = min(triton_tokens, max_blocks)
406-
grid = (num_blocks,)
407-
408-
# Create temporary tensor for new tail
409-
new_free_list_tail = torch.empty(1, dtype=torch.int64, device=self.device)
410-
411-
# Launch paged stash pop kernel
412-
_paged_stash_pop_kernel[grid](
413-
paged_stash_buffer.buffer,
414-
triton_tensor,
415-
triton_num_tokens,
416-
self.page_record, # Triton kernel will read from page_record
417-
paged_stash_buffer.free_list,
418-
paged_stash_buffer.free_list_head,
419-
paged_stash_buffer.free_list_tail,
420-
paged_stash_buffer.free_list_capacity,
421-
new_free_list_tail,
422-
PAGE_SIZE=self.page_size,
423-
HIDDEN_SIZE=self.hidden_size,
424-
BLOCK_SIZE=BLOCK_SIZE,
425-
)
426-
427-
# Update free list tail
428-
paged_stash_buffer.free_list_tail.copy_(new_free_list_tail)
363+
# Determine grid size
364+
BLOCK_SIZE = GLOBAL_BLOCK_SIZE
365+
num_blocks = min(self.max_num_tokens, max_blocks)
366+
grid = (num_blocks,)
367+
368+
# Create temporary tensor for new tail
369+
new_free_list_tail = torch.empty(1, dtype=torch.int64, device=self.device)
370+
371+
# Launch paged stash pop kernel
372+
_paged_stash_pop_kernel[grid](
373+
paged_stash_buffer.buffer,
374+
tensor_to_reload,
375+
self.num_tokens_tensor,
376+
self.page_record, # Triton kernel will read from page_record
377+
paged_stash_buffer.free_list,
378+
paged_stash_buffer.free_list_head,
379+
paged_stash_buffer.free_list_tail,
380+
paged_stash_buffer.free_list_capacity,
381+
new_free_list_tail,
382+
PAGE_SIZE=self.page_size,
383+
HIDDEN_SIZE=self.hidden_size,
384+
BLOCK_SIZE=BLOCK_SIZE,
385+
)
386+
387+
# Update free list tail
388+
paged_stash_buffer.free_list_tail.copy_(new_free_list_tail)
429389

430390

431391
class PP_PreScheduleFunction(torch.autograd.Function):
@@ -555,9 +515,6 @@ def __init__(self):
555515

556516
# Page size for paged memory management
557517
self.page_size = int(os.getenv('PAGED_STASH_PAGE_SIZE', '64')) # Default 64 tokens per page
558-
559-
# Number of pages to copy using native PyTorch (D2D)
560-
self.num_d2d_pages = int(os.getenv('NUM_D2D_PAGES', '0')) # Default 0 (all Triton)
561518

562519
@property
563520
def pack_stream(self):
@@ -765,7 +722,6 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
765722
layer_name=self._current_layer_name,
766723
max_tokens=self.max_num_tokens,
767724
page_size=self.page_size,
768-
num_d2d_pages=self.num_d2d_pages
769725
)
770726

771727
if self.status == 'captured':

megatron/core/transformer/moe/token_dispatcher.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
permute,
3434
sort_chunks_by_idxs,
3535
unpermute,
36-
drop_routing_map_triton,
3736
)
3837
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
3938
from megatron.core.transformer.transformer_config import TransformerConfig

0 commit comments

Comments (0)