File tree Expand file tree Collapse file tree 2 files changed +2
-0
lines changed
Expand file tree Collapse file tree 2 files changed +2
-0
lines changed Original file line number Diff line number Diff line change @@ -270,6 +270,7 @@ def fully_shard(
270270 offload_policy = CPUOffloadPolicy () if self .fsdp_config .cpu_offload else OffloadPolicy (),
271271 )
272272 self .set_modules_to_forward_prefetch ([self .embed_tokens , self .layers ["0" ]]) # type: ignore
273+ list (self .layers .values ())[- 1 ].set_modules_to_forward_prefetch ([self .norm , self .lm_head ]) # type: ignore
273274
274275 self ._to_empty_meta ()
275276
Original file line number Diff line number Diff line change @@ -780,6 +780,7 @@ def fully_shard(
780780 offload_policy = CPUOffloadPolicy () if self .fsdp_config .cpu_offload else OffloadPolicy (),
781781 )
782782 self .set_modules_to_forward_prefetch ([self .embed_tokens , self .layers ["0" ]]) # type: ignore
783+ list (self .layers .values ())[- 1 ].set_modules_to_forward_prefetch ([self .norm , self .lm_head ]) # type: ignore
783784
784785 for _ , module in self .named_modules ():
785786 if isinstance (module , nn .Embedding ):
You can’t perform that action at this time.
0 commit comments