File tree Expand file tree Collapse file tree 2 files changed +2
-0
lines changed
Expand file tree Collapse file tree 2 files changed +2
-0
lines changed Original file line number Diff line number Diff line change @@ -269,6 +269,7 @@ def fully_shard(
269269 offload_policy = CPUOffloadPolicy () if self .fsdp_config .cpu_offload else None ,
270270 )
271271 self .set_modules_to_forward_prefetch ([self .embed_tokens , self .layers ["0" ]]) # type: ignore
272+ list (self .layers .values ())[- 1 ].set_modules_to_forward_prefetch ([self .norm , self .lm_head ]) # type: ignore
272273
273274 self ._to_empty_meta ()
274275
Original file line number Diff line number Diff line change @@ -779,6 +779,7 @@ def fully_shard(
779779 offload_policy = CPUOffloadPolicy () if self .fsdp_config .cpu_offload else None ,
780780 )
781781 self .set_modules_to_forward_prefetch ([self .embed_tokens , self .layers ["0" ]]) # type: ignore
782+ list (self .layers .values ())[- 1 ].set_modules_to_forward_prefetch ([self .norm , self .lm_head ]) # type: ignore
782783
783784 for _ , module in self .named_modules ():
784785 if isinstance (module , nn .Embedding ):
You can’t perform that action at this time.
0 commit comments