@@ -171,6 +171,7 @@ def __init__(
171171 checkpoint_folder : str = './checkpoints' ,
172172 overwrite_checkpoints : bool = False ,
173173 fabric_kwargs : dict = dict (),
174+ distributed_eval : bool = True ,
174175 fp16 : bool = False ,
175176 use_ema : bool = True ,
176177 ema_kwargs : dict = dict (
@@ -201,10 +202,16 @@ def __init__(
201202 self .fabric = fabric
202203 fabric .launch ()
203204
205+ # whether to evaluate on every rank, or only on the main (root) process
206+ # when evaluating on main only, the other ranks skip keeping an EMA copy of the model
207+
208+ self .distributed_eval = distributed_eval
209+ self .will_eval_or_test = self .is_main or distributed_eval
210+
204211 # exponential moving average
205212
206213 self .ema_model = None
207- self .has_ema = self .is_main and use_ema
214+ self .has_ema = self .will_eval_or_test and use_ema
208215
209216 if self .has_ema :
210217 self .ema_model = EMA (
@@ -282,16 +289,18 @@ def __init__(
282289 self .valid_every = valid_every
283290
284291 self .needs_valid = exists (valid_dataset )
292+ self .valid_dataloader = None
285293
286- if self .needs_valid and self .is_main :
294+ if self .needs_valid and self .will_eval_or_test :
287295 self .valid_dataset_size = len (valid_dataset )
288296 self .valid_dataloader = DataLoader_ (valid_dataset , batch_size = batch_size )
289297
290298 # testing dataloader on EMA model
291299
292300 self .needs_test = exists (test_dataset )
301+ self .test_dataloader = None
293302
294- if self .needs_test and self .is_main :
303+ if self .needs_test and self .will_eval_or_test :
295304 self .test_dataset_size = len (test_dataset )
296305 self .test_dataloader = DataLoader_ (test_dataset , batch_size = batch_size )
297306
@@ -306,6 +315,12 @@ def __init__(
306315
307316 fabric .setup_dataloaders (self .dataloader )
308317
318+ if exists (self .valid_dataloader ) and self .distributed_eval :
319+ fabric .setup_dataloaders (self .valid_dataloader )
320+
321+ if exists (self .test_dataloader ) and self .distributed_eval :
322+ fabric .setup_dataloaders (self .test_dataloader )
323+
309324 # scheduler
310325
311326 if not exists (scheduler ):
@@ -555,7 +570,7 @@ def __call__(
555570 # maybe validate with EMA model - on main only, or on every rank when distributed_eval is enabled
556571
557572 if (
558- self .is_main and
573+ self .will_eval_or_test and
559574 self .needs_valid and
560575 divisible_by (self .steps , self .valid_every )
561576 ):
@@ -585,6 +600,11 @@ def __call__(
585600
586601 valid_loss_breakdown = {f'valid_{ k } ' :v for k , v in valid_loss_breakdown .items ()}
587602
603+ # reduce valid loss breakdown
604+
605+ if self .distributed_eval :
606+ valid_loss_breakdown = self .fabric .all_reduce (valid_loss_breakdown , reduce_op = 'sum' )
607+
588608 # log
589609
590610 self .log (** valid_loss_breakdown )
@@ -598,7 +618,7 @@ def __call__(
598618
599619 # maybe test
600620
601- if self .is_main and self .needs_test :
621+ if self .will_eval_or_test and self .needs_test :
602622 eval_model = default (self .ema_model , self .model )
603623
604624 with torch .no_grad (), to_device_and_back (eval_model , self .device ):
@@ -625,6 +645,11 @@ def __call__(
625645
626646 test_loss_breakdown = {f'test_{ k } ' :v for k , v in test_loss_breakdown .items ()}
627647
648+ # reduce test loss breakdown (summed across ranks)
649+
650+ if self .distributed_eval :
651+ test_loss_breakdown = self .fabric .all_reduce (test_loss_breakdown , reduce_op = 'sum' )
652+
628653 # log
629654
630655 self .log (** test_loss_breakdown )
0 commit comments