Skip to content

Commit 741052b

Browse files
committed
Use get_attr_wrapped_model util to access moe and mtp layers
1 parent 88ccf26 commit 741052b

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

megatron/core/transformer/moe/paged_stash.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer
1111
from megatron.core.full_cuda_graph import FullCudaGraphWrapper
12+
from megatron.core.utils import get_attr_wrapped_model
1213

1314
GLOBAL_BLOCK_SIZE = 1024
1415
SCALE_INV_BLOCK_SIZE = 32
@@ -1097,14 +1098,17 @@ def __init__(self, config, copy_main_params, model, optimizer, forward_backward_
10971098
self.forward_backward_func = forward_backward_func
10981099
self.moe_layers = []
10991100
for model_chunk in self.model:
1100-
for layer in model_chunk.module.module.decoder.layers:
1101+
model_with_decoder = get_attr_wrapped_model(
1102+
model_chunk, "decoder", allow_none=False, return_model_obj=True
1103+
)
1104+
for layer in model_with_decoder.decoder.layers:
11011105
mlp = layer.mlp
11021106
if hasattr(mlp, 'token_dispatcher') and hasattr(
11031107
mlp.token_dispatcher, 'check_over_budget'
11041108
):
11051109
self.moe_layers.append(mlp)
1106-
if model_chunk.module.module.mtp_process:
1107-
for layer in model_chunk.module.module.mtp.layers:
1110+
if model_with_decoder.mtp_process:
1111+
for layer in model_with_decoder.mtp.layers:
11081112
mlp = layer.mtp_model_layer.mlp
11091113
if hasattr(mlp, 'token_dispatcher') and hasattr(
11101114
mlp.token_dispatcher, 'check_over_budget'

0 commit comments

Comments (0)