Commit fd6bbef
Major cleanup and refactor

Get rid of legacy names like packed offloading. Move the main code body of paged stash to transformer/moe/.

1 parent: 0b3ccfd

File tree: 9 files changed (+59, -1245 lines)
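As far as this diff shows, the rename is mechanical: each symbol in the legacy megatron.core.pipeline_parallel.moe_packed_offload module maps one-to-one onto the new megatron.core.transformer.moe.paged_stash module. A migration sketch for downstream imports, derived only from the call sites changed below:

# Before this commit:
# from megatron.core.pipeline_parallel.moe_packed_offload import (
#     packed_moe_expert_offloading_reset,
#     packed_moe_expert_offloading_set_last_layer,
#     packed_moe_expert_offloading_init_chunk_handler,
# )
# After this commit:
from megatron.core.transformer.moe.paged_stash import (
    paged_stash_reset,
    paged_stash_set_last_layer,
    paged_stash_init_chunk_handler,
)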

megatron/core/full_cuda_graph.py (6 additions, 6 deletions)

@@ -7,8 +7,8 @@
 import torch

 from megatron.core.tensor_parallel.random import get_all_rng_states
-from megatron.core.pipeline_parallel.moe_packed_offload import (
-    packed_moe_expert_offloading_reset,
+from megatron.core.transformer.moe.paged_stash import (
+    paged_stash_reset,
 )

 logger = logging.getLogger(__name__)

@@ -101,11 +101,11 @@ class FullCudaGraphWrapper:
     cuda_graph = {'training': None, 'validation': None}
     result = {'training': None, 'validation': None}

-    def __init__(self, forward_backward_func, cuda_graph_warmup_steps=1, packed_moe_expert_offloading=False):
+    def __init__(self, forward_backward_func, cuda_graph_warmup_steps=1, moe_paged_stash=False):
         self.forward_backward_func = forward_backward_func
         self.static_loader = StaticBufferLoader()
         self.cuda_graph_warmup_steps = cuda_graph_warmup_steps
-        self.packed_moe_expert_offloading = packed_moe_expert_offloading
+        self.moe_paged_stash = moe_paged_stash

     def data_read(self, data_iterator, model, training, num_microbatches):
         """Read all microbatch inputs from Dataloader and copy to static buffers."""

@@ -188,15 +188,15 @@ def __call__(self, *args, **kwargs):
         if FullCudaGraphWrapper.cuda_graph[training_str] is None:
             FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(*args, **kwargs)
         else:
-            packed_moe_expert_offloading_reset(enabled=self.packed_moe_expert_offloading and training)
+            paged_stash_reset(enabled=self.moe_paged_stash and training)
             FullCudaGraphWrapper.cuda_graph[training_str].replay()
         self.speculative_cuda_graph_check(model)
         self.next_iter(training_str)
         return FullCudaGraphWrapper.result[training_str]

     def speculative_cuda_graph_check(self, model):
         ''' check speculative execution modules '''
-        if self.packed_moe_expert_offloading:
+        if self.moe_paged_stash:
             # Check if there is any overflow in the receiving buffer
             over_budget = torch.zeros(1, dtype=torch.bool, device='cuda')
             for model_chunk in model:
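A hypothetical call-site sketch based only on the constructor signature visible in this hunk; my_forward_backward_func and the flag value are stand-ins, not code from this commit:

from megatron.core.full_cuda_graph import FullCudaGraphWrapper

def my_forward_backward_func(*args, **kwargs):
    ...  # stand-in for the real pipeline-parallel forward/backward schedule

wrapped = FullCudaGraphWrapper(
    my_forward_backward_func,
    cuda_graph_warmup_steps=1,
    moe_paged_stash=True,  # renamed from packed_moe_expert_offloading
)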

megatron/core/models/common/model_chunk_schedule_plan.py (6 additions, 6 deletions)

@@ -11,8 +11,8 @@
 from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
     fine_grained_offloading_set_last_layer,
 )
-from megatron.core.pipeline_parallel.moe_packed_offload import (
-    packed_moe_expert_offloading_set_last_layer,
+from megatron.core.transformer.moe.paged_stash import (
+    paged_stash_set_last_layer,
 )
 from megatron.core.pipeline_parallel.utils import (
     AbstractSchedulePlan,

@@ -501,8 +501,8 @@ def run(
             fine_grained_offloading_set_last_layer(i == f_num_layers - 1)
         b_layer = b_schedule_plan.pop_layer()
         torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_schedule_plan.num_layers()}b")
-        if f_layer.layer.config.packed_moe_expert_offloading:
-            packed_moe_expert_offloading_set_last_layer(i == f_num_layers - 1)
+        if f_layer.layer.config.moe_paged_stash:
+            paged_stash_set_last_layer(i == f_num_layers - 1)
         f_input, b_grad = TransformerLayerSchedulePlan.run(
             f_layer,
             b_layer,

@@ -531,8 +531,8 @@ def run(
         torch.cuda.nvtx.range_push(f"layer_{i}f")
         if f_layer.layer.config.fine_grained_activation_offloading:
             fine_grained_offloading_set_last_layer(i == f_num_layers - 1)
-        if f_layer.layer.config.packed_moe_expert_offloading:
-            packed_moe_expert_offloading_set_last_layer(i == f_num_layers - 1)
+        if f_layer.layer.config.moe_paged_stash:
+            paged_stash_set_last_layer(i == f_num_layers - 1)
         f_input, _ = TransformerLayerSchedulePlan.run(f_layer, None, f_input=f_input)
         torch.cuda.nvtx.range_pop()
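The set_last_layer calls above suggest a module-level toggle, mirroring fine_grained_offloading_set_last_layer. A minimal sketch under that assumption; the actual body of paged_stash.py is not part of this diff and may differ:

# Assumed shape of the setter in megatron/core/transformer/moe/paged_stash.py;
# the _IS_LAST_LAYER flag name is hypothetical.
_IS_LAST_LAYER = False

def paged_stash_set_last_layer(is_last_layer: bool) -> None:
    """Record whether the layer about to run is the last in its chunk."""
    global _IS_LAST_LAYER
    _IS_LAST_LAYER = is_last_layer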

megatron/core/models/gpt/gpt_model.py (9 additions, 9 deletions)

@@ -21,8 +21,8 @@
 from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
     fine_grained_offloading_init_chunk_handler,
 )
-from megatron.core.pipeline_parallel.moe_packed_offload import (
-    packed_moe_expert_offloading_init_chunk_handler,
+from megatron.core.transformer.moe.paged_stash import (
+    paged_stash_init_chunk_handler,
 )
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.quantization.utils import get_quant_config_or_none

@@ -435,9 +435,9 @@ def preprocess_for_fine_grained_offloading(self):
             param.offloading_activation = False
         self.disable_param_offloading = False

-    def preprocess_for_packed_moe_expert_offloading(self):
-        """Preprocess for packed moe expert offloading."""
-        return packed_moe_expert_offloading_init_chunk_handler(
+    def preprocess_for_paged_stash(self):
+        """Preprocess for paged stash."""
+        return paged_stash_init_chunk_handler(
             vp_size=self.config.virtual_pipeline_model_parallel_size,
             vp_stage=self.vp_stage,
         )

@@ -470,8 +470,8 @@ def forward(
         if self.config.fine_grained_activation_offloading:
             self.preprocess_for_fine_grained_offloading()

-        if self.config.packed_moe_expert_offloading:
-            self.preprocess_for_packed_moe_expert_offloading()
+        if self.config.moe_paged_stash:
+            self.preprocess_for_paged_stash()

         inference_context = deprecate_inference_params(inference_context, inference_params)

@@ -770,8 +770,8 @@ def build_schedule_plan(

         if self.config.fine_grained_activation_offloading:
             self.preprocess_for_fine_grained_offloading()
-        if self.config.packed_moe_expert_offloading:
-            self.preprocess_for_packed_moe_expert_offloading()
+        if self.config.moe_paged_stash:
+            self.preprocess_for_paged_stash()

         from ..common.model_chunk_schedule_plan import TransformerModelChunkSchedulePlan
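preprocess_for_paged_stash forwards the virtual-pipeline size and stage into the chunk handler, so the stash state is presumably tracked per model chunk. A minimal sketch consistent with the paged_stash_init_chunk_handler(vp_size=..., vp_stage=...) call above; everything except the function name and its keyword arguments is an assumption:

# Hypothetical per-chunk registry; the real handler state lives in paged_stash.py.
_CHUNK_HANDLERS: dict = {}

def paged_stash_init_chunk_handler(vp_size, vp_stage):
    """Create (or reuse) the stash handler for one virtual-pipeline chunk."""
    key = (vp_size, vp_stage)
    if key not in _CHUNK_HANDLERS:
        _CHUNK_HANDLERS[key] = {"stash_pages": [], "over_budget": False}
    return _CHUNK_HANDLERS[key]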
