2020
2121from torchtnt .framework .state import State
2222from torchtnt .utils .lr_scheduler import TLRScheduler
23- from torchtnt .utils .prepare_module import _is_fsdp_module , FSDPOptimizerWrapper
23+ from torchtnt .utils .prepare_module import (
24+ _is_fsdp2_module ,
25+ _is_fsdp_module ,
26+ FSDP2OptimizerWrapper ,
27+ FSDPOptimizerWrapper ,
28+ )
2429from torchtnt .utils .progress import Progress
2530from torchtnt .utils .stateful import MetricStateful , Stateful
2631
@@ -199,13 +204,27 @@ def __delattr__(self, name: str) -> None:
199204
200205 def _construct_tracked_optimizers_and_schedulers (
201206 self ,
202- ) -> Dict [str , Union [torch .optim .Optimizer , FSDPOptimizerWrapper , TLRScheduler ]]:
207+ ) -> Dict [
208+ str ,
209+ Union [
210+ torch .optim .Optimizer ,
211+ FSDPOptimizerWrapper ,
212+ FSDP2OptimizerWrapper ,
213+ TLRScheduler ,
214+ ],
215+ ]:
203216 """
204- Combines tracked optimizers and schedulers. Handles optimizers working on FSDP modules, wrapping them in FSDPOptimizerWrapper.
217+ Combines tracked optimizers and schedulers. Handles optimizers working on FSDP modules, wrapping them in FSDPOptimizerWrapper/FSDP2OptimizerWrapper .
205218 """
206219 # construct custom tracked optimizers with FSDP optimizers
207220 tracked_optimizers_and_schedulers : Dict [
208- str , Union [torch .optim .Optimizer , FSDPOptimizerWrapper , TLRScheduler ]
221+ str ,
222+ Union [
223+ torch .optim .Optimizer ,
224+ FSDPOptimizerWrapper ,
225+ FSDP2OptimizerWrapper ,
226+ TLRScheduler ,
227+ ],
209228 ] = {}
210229 tracked_optimizers_and_schedulers .update (self ._construct_tracked_optimizers ())
211230
@@ -224,25 +243,38 @@ def _construct_tracked_optimizers_and_schedulers(
224243
225244 def _construct_tracked_optimizers (
226245 self ,
227- ) -> Dict [str , Union [torch .optim .Optimizer , FSDPOptimizerWrapper ]]:
246+ ) -> Dict [
247+ str , Union [torch .optim .Optimizer , FSDPOptimizerWrapper , FSDP2OptimizerWrapper ]
248+ ]:
228249 """
229- Constructs tracked optimizers. Handles optimizers working on FSDP modules, wrapping them in FSDPOptimizerWrapper.
250+ Constructs tracked optimizers. Handles optimizers working on FSDP modules, wrapping them in FSDPOptimizerWrapper/FSDP2OptimizerWrapper .
230251 """
231- fsdp_tracked_optimizers : Dict [str , FSDPOptimizerWrapper ] = {}
252+ fsdp_tracked_optimizers : Dict [
253+ str , Union [FSDPOptimizerWrapper , FSDP2OptimizerWrapper ]
254+ ] = {}
232255 for module in self .tracked_modules ().values ():
233256 if _is_fsdp_module (module ):
234257 # find optimizers for module, if exists
235258 optimizer_list = _find_optimizers_for_module (
236259 module , self .tracked_optimizers ()
237260 )
261+
262+ is_fsdp2 = _is_fsdp2_module (module )
263+
238264 for optim_name , optimizer in optimizer_list :
239- fsdp_tracked_optimizers [optim_name ] = FSDPOptimizerWrapper (
240- module , optimizer
241- )
265+ if is_fsdp2 :
266+ fsdp_tracked_optimizers [optim_name ] = FSDP2OptimizerWrapper (
267+ module , optimizer
268+ )
269+ else :
270+ fsdp_tracked_optimizers [optim_name ] = FSDPOptimizerWrapper (
271+ module , optimizer
272+ )
242273
243274 # construct custom tracked optimizers with FSDP optimizers
244275 tracked_optimizers : Dict [
245- str , Union [torch .optim .Optimizer , FSDPOptimizerWrapper ]
276+ str ,
277+ Union [torch .optim .Optimizer , FSDPOptimizerWrapper , FSDP2OptimizerWrapper ],
246278 ] = {
247279 key : value
248280 for key , value in self .tracked_optimizers ().items ()
0 commit comments