Commit f9dee21

[Enhance] Add layers[-1] to norm & lm_head prefetch

1 parent 8dae899

2 files changed: +2 −0

xtuner/v1/model/dense/dense.py (1 addition, 0 deletions)

@@ -269,6 +269,7 @@ def fully_shard(
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
         )
         self.set_modules_to_forward_prefetch([self.embed_tokens, self.layers["0"]])  # type: ignore
+        list(self.layers.values())[-1].set_modules_to_forward_prefetch([self.norm, self.lm_head])  # type: ignore

         self._to_empty_meta()

xtuner/v1/model/moe/moe.py (1 addition, 0 deletions)

@@ -779,6 +779,7 @@ def fully_shard(
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
         )
         self.set_modules_to_forward_prefetch([self.embed_tokens, self.layers["0"]])  # type: ignore
+        list(self.layers.values())[-1].set_modules_to_forward_prefetch([self.norm, self.lm_head])  # type: ignore

         for _, module in self.named_modules():
             if isinstance(module, nn.Embedding):
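
Both hunks make the same change. The root model already prefetches embed_tokens and the first decoder layer at the start of forward; the added line makes the last decoder layer (xtuner keys its layers by string, hence layers["0"] and list(self.layers.values())[-1]) prefetch norm and lm_head, so their FSDP all-gathers can overlap with the last block's computation instead of being issued only when those modules are reached. Below is a minimal sketch of this prefetch chaining with FSDP2's set_modules_to_forward_prefetch; the ToyLM model, its sizes, and the torchrun launch are illustrative assumptions, not xtuner code, and it assumes PyTorch >= 2.6 (where fully_shard is exported from torch.distributed.fsdp).

# Minimal sketch (illustrative, not xtuner code) of FSDP2 explicit forward
# prefetch, mirroring this commit. Assumes PyTorch >= 2.6 and a launch via
# `torchrun --nproc-per-node=<ngpus> this_script.py` so a process group exists.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard


class ToyLM(nn.Module):
    def __init__(self, vocab: int = 1000, dim: int = 64, n_layers: int = 4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(dim, nhead=4, batch_first=True)
            for _ in range(n_layers)
        )
        self.norm = nn.LayerNorm(dim)
        self.lm_head = nn.Linear(dim, vocab)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.embed_tokens(x)
        for layer in self.layers:
            h = layer(h)
        return self.lm_head(self.norm(h))


dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = ToyLM().cuda()
for layer in model.layers:
    fully_shard(layer)             # each decoder block is its own FSDP unit
fully_shard(model.embed_tokens)
fully_shard(model.norm)
fully_shard(model.lm_head)
fully_shard(model)                 # root FSDP module wraps the remainder

# As before the commit: the root prefetches the first compute units so their
# all-gathers are issued as soon as forward begins.
model.set_modules_to_forward_prefetch([model.embed_tokens, model.layers[0]])
# The commit's addition: the last block prefetches norm and lm_head, so their
# all-gathers run alongside the last block's computation instead of being
# issued only at each module's own pre-forward.
model.layers[-1].set_modules_to_forward_prefetch([model.norm, model.lm_head])

tokens = torch.randint(0, 1000, (2, 16), device="cuda")
logits = model(tokens)             # prefetch chain exercised during forward
dist.destroy_process_group()

Note: set_modules_to_forward_prefetch makes the calling FSDP module issue the listed modules' parameter all-gathers in its own pre-forward rather than waiting for each module's pre-forward; that earlier issue point is what creates the communication/computation overlap.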
