Commit 10e3805

[Enhance] Add layers[-1] to norm & lm_head prefetch

committed · 1 parent cdc5b30 · commit 10e3805

File tree

2 files changed: +2 -0 lines changed


xtuner/v1/model/dense/dense.py

Lines changed: 1 addition & 0 deletions
@@ -270,6 +270,7 @@ def fully_shard(
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else OffloadPolicy(),
         )
         self.set_modules_to_forward_prefetch([self.embed_tokens, self.layers["0"]])  # type: ignore
+        list(self.layers.values())[-1].set_modules_to_forward_prefetch([self.norm, self.lm_head])  # type: ignore

         self._to_empty_meta()
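For context: under PyTorch FSDP2 (fully_shard), a sharded module all-gathers its parameters just before its own forward unless an earlier module is told to prefetch it. The root already prefetches embed_tokens and layers["0"]; the added line makes the last decoder layer prefetch norm and lm_head, so their all-gathers overlap with the last layer's compute instead of stalling after it. Below is a minimal sketch of the same pattern; TinyLM and its sizes are hypothetical stand-ins for xtuner's models, while fully_shard and set_modules_to_forward_prefetch are the real FSDP2 APIs (assumes PyTorch 2.4+ and an already-initialized process group / device mesh):

import torch.nn as nn
from torch.distributed.fsdp import fully_shard


class TinyLM(nn.Module):
    # Hypothetical stand-in for xtuner's dense/MoE models.
    def __init__(self, vocab: int = 128, dim: int = 64, n_layers: int = 4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)
        # ModuleDict keyed by strings, mirroring self.layers["0"] in the diff.
        self.layers = nn.ModuleDict({str(i): nn.Linear(dim, dim) for i in range(n_layers)})
        self.norm = nn.LayerNorm(dim)
        self.lm_head = nn.Linear(dim, vocab)


model = TinyLM()

# Shard the leaves first, then the root. Each fully_shard call turns the
# module into an FSDPModule in place, which is what makes
# set_modules_to_forward_prefetch available on it.
for sub in (model.embed_tokens, *model.layers.values(), model.norm, model.lm_head):
    fully_shard(sub)
fully_shard(model)

# The root prefetches the embedding and the first layer; the last layer
# prefetches norm and lm_head so the final all-gathers overlap with the last
# layer's compute instead of serializing after it.
model.set_modules_to_forward_prefetch([model.embed_tokens, model.layers["0"]])  # type: ignore
list(model.layers.values())[-1].set_modules_to_forward_prefetch([model.norm, model.lm_head])  # type: ignore

The same two-line prefetch chain is registered in both files touched by this commit: dense.py above and moe.py below.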

xtuner/v1/model/moe/moe.py

Lines changed: 1 addition & 0 deletions
@@ -780,6 +780,7 @@ def fully_shard(
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else OffloadPolicy(),
         )
         self.set_modules_to_forward_prefetch([self.embed_tokens, self.layers["0"]])  # type: ignore
+        list(self.layers.values())[-1].set_modules_to_forward_prefetch([self.norm, self.lm_head])  # type: ignore

         for _, module in self.named_modules():
             if isinstance(module, nn.Embedding):
