Commit d99b74f

Remove d2d page feature for now

Remove unused Triton kernel for dropping tokens when an overflow happens.

1 parent: 3db006b

File tree: 3 files changed, +55 −197 lines

megatron/core/transformer/moe/moe_utils.py

Lines changed: 1 addition & 98 deletions
@@ -34,11 +34,6 @@
     HAVE_TE = False
 
 
-import triton
-import triton.language as tl
-
-
-
 # MOE logging
 _MOE_LAYER_WISE_LOGGING_TRACKER = {}
 

@@ -934,7 +929,7 @@ def forward(ctx, logits):
         """
         Forward pass returns random logits with rank-specific seed.
        """
-        if RandomSTE.random_logits is not None:
+        if is_graph_capturing() and RandomSTE.random_logits is not None:
             return RandomSTE.random_logits
 
         if RandomSTE.generator is None:
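
This hunk narrows the RandomSTE fast path: the cached random_logits are now reused only while a CUDA graph is being captured, where tensor storage must stay static across replays; outside capture, fresh logits are generated. A minimal sketch of the pattern, using torch.cuda.is_current_stream_capturing() as a stand-in for the repo's is_graph_capturing() helper:

import torch

def is_graph_capturing() -> bool:
    # Stand-in for the repo's helper; assumes a CUDA build of PyTorch.
    return torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()

class RandomSTESketch:
    random_logits = None  # cached tensor, reused across graph replays

    @classmethod
    def get_logits(cls, make_logits):
        # Reuse the cache only during graph capture; otherwise regenerate.
        if is_graph_capturing() and cls.random_logits is not None:
            return cls.random_logits
        cls.random_logits = make_logits()
        return cls.random_logits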
@@ -1302,95 +1297,3 @@ def wrapped_func(moe_layer, *args, **kwargs):
         return wrapped_func
 
     return decorator
-
-@triton.jit
-def _drop_routing_map_kernel(
-    routing_map_ptr,
-    over_budget_ptr,
-    routing_map_dropped_ptr,
-    num_elements: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-):
-    """Triton kernel to drop routing map based on budget constraints.
-
-    Args:
-        routing_map_ptr: Pointer to the input routing_map tensor
-        over_budget_ptr: Pointer to the boolean tensor indicating if any EP rank is over budget
-        routing_map_dropped_ptr: Pointer to the output routing_map tensor
-        num_elements: Total number of elements to process
-        BLOCK_SIZE: Block size for Triton kernel
-    """
-    # Get the program ID
-    pid = tl.program_id(axis=0)
-
-    # Read the over_budget value (scalar tensor with single element)
-    over_budget_val = tl.load(over_budget_ptr)
-
-    # Calculate the offset for this program
-    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-
-    # Load the routing_map values
-    mask = offset < num_elements
-    routing_map_val = tl.load(routing_map_ptr + offset, mask=mask, other=0.0)
-
-    # If over_budget is 1 (True), output is 0 (drop); if over_budget is 0 (False), output is routing_map_val (keep)
-    output_val = routing_map_val * (1 - over_budget_val)
-
-    # Store the result
-    tl.store(routing_map_dropped_ptr + offset, output_val, mask=mask)
-
-
-def drop_routing_map_triton(
-    routing_map: torch.Tensor,
-    budget: torch.Tensor,
-    num_tokens_per_ep_rank: torch.Tensor,
-) -> torch.Tensor:
-    """Drop tokens from routing_map that exceed the budget per EP rank using Triton.
-
-    Args:
-        routing_map: Tensor indicating which tokens are assigned to each expert.
-        budget: Integer tensor with the maximum number of tokens per EP rank.
-        num_tokens_per_ep_rank: Tensor with actual number of tokens per EP rank.
-
-    Returns:
-        Modified routing_map with tokens exceeding budget zeroed out if any EP rank
-        exceeds budget, otherwise returns the original routing_map.
-    """
-
-    # Calculate boolean tensor: over_budget is True if ANY EP rank exceeds budget
-    over_budget = (num_tokens_per_ep_rank > budget).any()
-
-    # Convert boolean to int8
-    over_budget_int = over_budget.to(torch.int8)
-
-    # Convert routing_map to numeric type if it's boolean
-    if routing_map.dtype == torch.bool:
-        routing_map_numeric = routing_map.to(torch.int8)
-    else:
-        routing_map_numeric = routing_map
-
-    # Create output tensor with same dtype as input
-    routing_map_dropped = torch.empty_like(routing_map_numeric)
-
-    # Flatten tensors for kernel processing
-    routing_map_flat = routing_map_numeric.flatten()
-    num_elements = routing_map_flat.numel()
-
-    # Determine grid size
-    BLOCK_SIZE = 1024
-    grid = (triton.cdiv(num_elements, BLOCK_SIZE),)
-
-    # Launch kernel with over_budget tensor pointer (as int8)
-    _drop_routing_map_kernel[grid](
-        routing_map_flat,
-        over_budget_int,
-        routing_map_dropped.flatten(),
-        num_elements,
-        BLOCK_SIZE=BLOCK_SIZE,
-    )
-
-    # Convert back to boolean if original was boolean
-    if routing_map.dtype == torch.bool:
-        routing_map_dropped = routing_map_dropped.to(torch.bool)
-
-    return routing_map_dropped, over_budget.to(torch.bool)
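
For reference, the deleted helper's effect reduces to an elementwise multiply by a single drop flag, so it can be reproduced in plain PyTorch. A standalone sketch (hypothetical name drop_routing_map, not part of the repo) mirroring the removed logic:

import torch

def drop_routing_map(routing_map: torch.Tensor,
                     budget: torch.Tensor,
                     num_tokens_per_ep_rank: torch.Tensor):
    """Zero the routing map if ANY EP rank exceeds its token budget."""
    # Same test as the removed kernel wrapper.
    over_budget = (num_tokens_per_ep_rank > budget).any()
    if routing_map.dtype == torch.bool:
        dropped = routing_map & ~over_budget
    else:
        # output = routing_map * (1 - over_budget), as in the removed kernel.
        dropped = routing_map * (~over_budget).to(routing_map.dtype)
    return dropped, over_budget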

megatron/core/transformer/moe/paged_stash.py

Lines changed: 54 additions & 98 deletions
@@ -247,7 +247,7 @@ class PagedTensor:
     A paged tensor that stores data in pages within a paged stash buffer.
     """
 
-    def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer_no=None, layer_name=None, max_tokens=None, page_size=64, num_d2d_pages=0):
+    def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer_no=None, layer_name=None, max_tokens=None, page_size=64):
         """
         Args:
             tensor: The tensor to store
@@ -256,7 +256,6 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer
             layer_name: Name of the layer
             max_tokens: Maximum number of tokens
             page_size: Number of tokens per page
-            num_d2d_pages: Number of pages to copy using native PyTorch (rest uses Triton)
         """
         self._tensor = tensor
         self._original_tensor = None
@@ -267,7 +266,6 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer
         self.layer_name = layer_name
         self.max_tokens = max_tokens
         self.page_size = page_size
-        self.num_d2d_pages = num_d2d_pages
 
         # Original tensor information
         self.original_shape = list(tensor.shape)
@@ -282,13 +280,6 @@ def __init__(self, tensor, num_tokens_tensor=None, vp_stage=None, schedule_layer
 
         # Page record: stores which pages are being used for this tensor
         self.page_record = torch.zeros(self.max_num_pages, dtype=torch.int64, device=self.device)
-
-        # Static tensor for D2D pages (allocate upfront if needed)
-        d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens)
-        if d2d_tokens > 0:
-            self.static_tensor = torch.empty((d2d_tokens, self.hidden_size), dtype=self.dtype, device=self.device)
-        else:
-            self.static_tensor = None
 
     @property
     def schedule_layer(self):
@@ -312,48 +303,33 @@ def offload_to_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048
         else:
             tensor_to_copy = self._tensor
 
-        # Split tensor into two parts: D2D portion and Triton portion
-        # Use max_num_tokens for consistent size across iterations
-        d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens)
-        triton_tokens = self.max_num_tokens - d2d_tokens
-
-        # Perform both D2D copy and Triton kernel together
-        # Part 1: Copy first d2d_tokens to static_tensor using native PyTorch
-        if d2d_tokens > 0:
-            self.static_tensor[:d2d_tokens] = tensor_to_copy[:d2d_tokens]
-        # Part 2: Copy remaining tokens using Triton kernel
-        if triton_tokens > 0:
-            triton_tensor = tensor_to_copy[d2d_tokens:self.max_num_tokens]
-            # Use actual num_tokens for the kernel (how many tokens to actually copy)
-            triton_num_tokens = self.num_tokens_tensor - d2d_tokens
-
-            # Determine grid size
-            BLOCK_SIZE = GLOBAL_BLOCK_SIZE
-            num_blocks = min(triton_tokens, max_blocks)
-            grid = (num_blocks,)
-
-            # Create temporary tensor for new head
-            new_free_list_head = torch.empty(1, dtype=torch.int64, device=self.device)
-
-            # Launch paged stash copy kernel
-            _paged_stash_copy_kernel[grid](
-                triton_tensor,
-                paged_stash_buffer.buffer,
-                triton_num_tokens,
-                paged_stash_buffer.free_list,
-                paged_stash_buffer.free_list_head,
-                paged_stash_buffer.free_list_tail,
-                paged_stash_buffer.free_list_capacity,
-                self.page_record,  # Triton kernel will populate page_record
-                paged_stash_buffer.overflow,
-                new_free_list_head,
-                PAGE_SIZE=self.page_size,
-                HIDDEN_SIZE=self.hidden_size,
-                BLOCK_SIZE=BLOCK_SIZE,
-            )
-
-            # Update free list head
-            paged_stash_buffer.free_list_head.copy_(new_free_list_head)
+        # Determine grid size
+        BLOCK_SIZE = GLOBAL_BLOCK_SIZE
+        num_blocks = min(self.max_num_tokens, max_blocks)
+        grid = (num_blocks,)
+
+        # Create temporary tensor for new head
+        new_free_list_head = torch.empty(1, dtype=torch.int64, device=self.device)
+
+        # Launch paged stash copy kernel
+        _paged_stash_copy_kernel[grid](
+            tensor_to_copy,
+            paged_stash_buffer.buffer,
+            self.num_tokens_tensor,
+            paged_stash_buffer.free_list,
+            paged_stash_buffer.free_list_head,
+            paged_stash_buffer.free_list_tail,
+            paged_stash_buffer.free_list_capacity,
+            self.page_record,  # Triton kernel will populate page_record
+            paged_stash_buffer.overflow,
+            new_free_list_head,
+            PAGE_SIZE=self.page_size,
+            HIDDEN_SIZE=self.hidden_size,
+            BLOCK_SIZE=BLOCK_SIZE,
+        )
+
+        # Update free list head
+        paged_stash_buffer.free_list_head.copy_(new_free_list_head)
 
         # Save reference to original tensor
         self._original_tensor = self._tensor
@@ -384,48 +360,32 @@ def reload_from_stash(self, paged_stash_buffer: PagedStashBuffer, max_blocks=2048
         self._tensor = torch.empty(self.original_shape, dtype=self.dtype, device=self.device)
         tensor_to_reload = self._tensor
 
-        # Split tensor into two parts: D2D portion and Triton portion
-        # Use max_num_tokens for consistency with stash
-        d2d_tokens = min(self.num_d2d_pages * self.page_size, self.max_num_tokens)
-        triton_tokens = self.max_num_tokens - d2d_tokens
-
-        # Perform both D2D copy and Triton kernel together
-        # Part 1: Copy first d2d_tokens from static_tensor using native PyTorch
-        if d2d_tokens > 0 and self.static_tensor is not None:
-            tensor_to_reload[:d2d_tokens] = self.static_tensor[:d2d_tokens]
-
-        # Part 2: Copy remaining tokens using Triton kernel
-        if triton_tokens > 0:
-            triton_tensor = tensor_to_reload[d2d_tokens:self.max_num_tokens]
-            # Use actual num_tokens for the kernel (how many tokens to actually copy)
-            triton_num_tokens = self.num_tokens_tensor - d2d_tokens
-
-            # Determine grid size
-            BLOCK_SIZE = GLOBAL_BLOCK_SIZE
-            num_blocks = min(triton_tokens, max_blocks)
-            grid = (num_blocks,)
-
-            # Create temporary tensor for new tail
-            new_free_list_tail = torch.empty(1, dtype=torch.int64, device=self.device)
-
-            # Launch paged stash pop kernel
-            _paged_stash_pop_kernel[grid](
-                paged_stash_buffer.buffer,
-                triton_tensor,
-                triton_num_tokens,
-                self.page_record,  # Triton kernel will read from page_record
-                paged_stash_buffer.free_list,
-                paged_stash_buffer.free_list_head,
-                paged_stash_buffer.free_list_tail,
-                paged_stash_buffer.free_list_capacity,
-                new_free_list_tail,
-                PAGE_SIZE=self.page_size,
-                HIDDEN_SIZE=self.hidden_size,
-                BLOCK_SIZE=BLOCK_SIZE,
-            )
-
-            # Update free list tail
-            paged_stash_buffer.free_list_tail.copy_(new_free_list_tail)
+        # Determine grid size
+        BLOCK_SIZE = GLOBAL_BLOCK_SIZE
+        num_blocks = min(self.max_num_tokens, max_blocks)
+        grid = (num_blocks,)
+
+        # Create temporary tensor for new tail
+        new_free_list_tail = torch.empty(1, dtype=torch.int64, device=self.device)
+
+        # Launch paged stash pop kernel
+        _paged_stash_pop_kernel[grid](
+            paged_stash_buffer.buffer,
+            tensor_to_reload,
+            self.num_tokens_tensor,
+            self.page_record,  # Triton kernel will read from page_record
+            paged_stash_buffer.free_list,
+            paged_stash_buffer.free_list_head,
+            paged_stash_buffer.free_list_tail,
+            paged_stash_buffer.free_list_capacity,
+            new_free_list_tail,
+            PAGE_SIZE=self.page_size,
+            HIDDEN_SIZE=self.hidden_size,
+            BLOCK_SIZE=BLOCK_SIZE,
+        )
+
+        # Update free list tail
+        paged_stash_buffer.free_list_tail.copy_(new_free_list_tail)
 
 
 class PP_PreScheduleFunction(torch.autograd.Function):
@@ -555,9 +515,6 @@ def __init__(self):
 
         # Page size for paged memory management
         self.page_size = int(os.getenv('PAGED_STASH_PAGE_SIZE', '64'))  # Default 64 tokens per page
-
-        # Number of pages to copy using native PyTorch (D2D)
-        self.num_d2d_pages = int(os.getenv('NUM_D2D_PAGES', '0'))  # Default 0 (all Triton)
 
     @property
     def pack_stream(self):
@@ -765,7 +722,6 @@ def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
            layer_name=self._current_layer_name,
            max_tokens=self.max_num_tokens,
            page_size=self.page_size,
-           num_d2d_pages=self.num_d2d_pages
        )
 
        if self.status == 'captured':
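
For context, the removed num_d2d_pages feature split every stash copy into a native device-to-device (D2D) prefix and a Triton remainder; this commit routes the whole tensor through the Triton kernel instead. A condensed, self-contained sketch of the dropped split (names mirror the deleted code; illustrative only):

import torch

def split_for_stash(tensor_to_copy: torch.Tensor, static_tensor: torch.Tensor,
                    num_d2d_pages: int, page_size: int, max_num_tokens: int) -> torch.Tensor:
    # The first num_d2d_pages * page_size tokens went through a plain
    # PyTorch D2D assignment into the preallocated static tensor ...
    d2d_tokens = min(num_d2d_pages * page_size, max_num_tokens)
    if d2d_tokens > 0:
        static_tensor[:d2d_tokens] = tensor_to_copy[:d2d_tokens]
    # ... and only the remaining slice was handed to the Triton kernel.
    return tensor_to_copy[d2d_tokens:max_num_tokens]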

megatron/core/transformer/moe/token_dispatcher.py

Lines changed: 0 additions & 1 deletion
@@ -33,7 +33,6 @@
     permute,
     sort_chunks_by_idxs,
     unpermute,
-    drop_routing_map_triton,
 )
 from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
 from megatron.core.transformer.transformer_config import TransformerConfig
