     AverageMeter,
 )
 from world_models.utils.jepa_utils import repeat_interleave_batch
-from world_models.datasets.imagenet1k import make_imagenet1k
+from world_models.datasets.imagenet1k import make_imagenet1k, make_imagefolder
+from world_models.datasets.cifar10 import make_cifar10
 from world_models.helpers.jepa_helper import load_checkpoint, init_model, init_opt
 from world_models.transforms.transforms import make_transforms
 from world_models.configs.jepa_config import JEPAConfig
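For reference, the call site added below implies `make_cifar10` follows the same factory shape as `make_imagenet1k`. A hypothetical signature consistent with that call; the real implementation in `world_models/datasets/cifar10.py` is not part of this diff and may differ:

```python
# Hypothetical sketch inferred only from the call site in this diff;
# the actual make_cifar10 in world_models/datasets/cifar10.py may differ.
def make_cifar10(transform, batch_size, collator=None, pin_mem=True,
                 num_workers=8, world_size=1, rank=0, root_path=None,
                 drop_last=True, train=True, download=False):
    ...
    # returns (dataset, data_loader, dist_sampler), matching the
    # `_, unsupervised_loader, unsupervised_sampler = ...` unpacking below
```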
@@ -181,20 +182,52 @@ def main(args, resume_preempt=False):
     )
 
     # -- init data-loaders/samplers
-    _, unsupervised_loader, unsupervised_sampler = make_imagenet1k(
-        transform=transform,
-        batch_size=batch_size,
-        collator=mask_collator,
-        pin_mem=pin_mem,
-        training=True,
-        num_workers=num_workers,
-        world_size=world_size,
-        rank=rank,
-        root_path=root_path,
-        image_folder=image_folder,
-        copy_data=copy_data,
-        drop_last=True,
-    )
+    dataset_type = args["data"]["dataset"]
+    val_split = args["data"]["val_split"]
+    download = args["data"].get("download", False)
+    if dataset_type.lower() == "imagenet":
+        _, unsupervised_loader, unsupervised_sampler = make_imagenet1k(
+            transform=transform,
+            batch_size=batch_size,
+            collator=mask_collator,
+            pin_mem=pin_mem,
+            training=True,
+            num_workers=num_workers,
+            world_size=world_size,
+            rank=rank,
+            root_path=root_path,
+            image_folder=image_folder,
+            copy_data=copy_data,
+            drop_last=True,
+        )
+    elif dataset_type.lower() == "cifar10":
+        _, unsupervised_loader, unsupervised_sampler = make_cifar10(
+            transform=transform,
+            batch_size=batch_size,
+            collator=mask_collator,
+            pin_mem=pin_mem,
+            num_workers=num_workers,
+            world_size=world_size,
+            rank=rank,
+            root_path=root_path,
+            drop_last=True,
+            train=True,
+            download=download,  # pass through
+        )
+    else:
+        _, unsupervised_loader, unsupervised_sampler = make_imagefolder(
+            transform=transform,
+            batch_size=batch_size,
+            collator=mask_collator,
+            pin_mem=pin_mem,
+            num_workers=num_workers,
+            world_size=world_size,
+            rank=rank,
+            root_path=root_path,
+            image_folder=image_folder,
+            drop_last=True,
+            val_split=val_split,
+        )
     ipe = len(unsupervised_loader)
 
     # -- init optimizer and scheduler
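The new branching reads three keys from the `data` section of the config. A minimal sketch of the `args` structure the code above expects; the key names come from the diff, the values here are purely illustrative, and the exact semantics of `val_split` inside `make_imagefolder` are an assumption:

```python
# Illustrative values only; key names are taken from the diff above.
args = {
    "data": {
        "dataset": "cifar10",   # "imagenet", "cifar10", or anything else (falls through to make_imagefolder)
        "val_split": 0.1,       # assumed: fraction make_imagefolder holds out for validation
        "download": True,       # optional; read via .get("download", False)
    },
}
```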
@@ -212,9 +245,17 @@ def main(args, resume_preempt=False):
         ipe_scale=ipe_scale,
         use_bfloat16=use_bfloat16,
     )
-    encoder = DistributedDataParallel(encoder, static_graph=True)
-    predictor = DistributedDataParallel(predictor, static_graph=True)
-    target_encoder = DistributedDataParallel(target_encoder)
+
+    is_distributed = (
+        torch.distributed.is_available()
+        and torch.distributed.is_initialized()
+        and world_size > 1
+    )
+    if is_distributed:
+        encoder = DistributedDataParallel(encoder, static_graph=True)
+        predictor = DistributedDataParallel(predictor, static_graph=True)
+        target_encoder = DistributedDataParallel(target_encoder)
+    # keep modules unwrapped when not distributed
     for p in target_encoder.parameters():
         p.requires_grad = False
 
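Gating the DDP wrap on an initialized process group lets the same script run single-GPU without any `torch.distributed` setup. A self-contained sketch of the same pattern; this is an assumed helper for illustration, not this repo's code:

```python
# Minimal sketch of the conditional-DDP pattern used above:
# wrap only when a process group is actually running.
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

def maybe_wrap_ddp(module: nn.Module, world_size: int, **ddp_kwargs) -> nn.Module:
    if (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and world_size > 1
    ):
        return DistributedDataParallel(module, **ddp_kwargs)
    return module  # single-process run: leave the module unwrapped
```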
@@ -328,7 +369,8 @@ def loss_fn(z, h):
             else:
                 loss.backward()
                 optimizer.step()
-            grad_stats = grad_logger(encoder.named_parameters())
+            enc_for_log = encoder.module if is_distributed else encoder
+            grad_stats = grad_logger(enc_for_log.named_parameters())
             optimizer.zero_grad()
 
             # Step 3. momentum update of target encoder
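The unwrap before logging matters because DDP keeps the user's network under `.module` and prefixes every parameter name with `module.`; without it, gradient stats would be keyed differently in distributed and single-process runs. A duck-typed variant of the same idea, as a hypothetical helper not present in the diff:

```python
import torch.nn as nn

def unwrap(model: nn.Module) -> nn.Module:
    # DistributedDataParallel stores the wrapped network at .module;
    # plain modules pass through unchanged, so named_parameters()
    # yields the same keys in both modes.
    return model.module if hasattr(model, "module") else model
```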