diff --git a/xtuner/v1/model/dense/dense.py b/xtuner/v1/model/dense/dense.py index 02298350b..074811128 100644 --- a/xtuner/v1/model/dense/dense.py +++ b/xtuner/v1/model/dense/dense.py @@ -269,6 +269,7 @@ def fully_shard( offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None, ) self.set_modules_to_forward_prefetch([self.embed_tokens, self.layers["0"]]) # type: ignore + list(self.layers.values())[-1].set_modules_to_forward_prefetch([self.norm, self.lm_head]) # type: ignore self._to_empty_meta() diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 2f9abe8c4..22e4e6438 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -779,6 +779,7 @@ def fully_shard( offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None, ) self.set_modules_to_forward_prefetch([self.embed_tokens, self.layers["0"]]) # type: ignore + list(self.layers.values())[-1].set_modules_to_forward_prefetch([self.norm, self.lm_head]) # type: ignore for _, module in self.named_modules(): if isinstance(module, nn.Embedding):