Skip to content

Commit 0ac3cac

Browse files
Add dataset support and related changes
1 parent f03ddcb commit 0ac3cac

File tree

8 files changed

+174
-27
lines changed

8 files changed

+174
-27
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ results/
44
dist/*
55
pytorch_world.egg-info/
66
world_models/models/data/
7+
cifar

jepa_try.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,21 @@
11
from world_models.models.jepa_agent import JEPAAgent
22

3-
agent = JEPAAgent(
4-
folder="results/jepa_try",
5-
write_tag="jepa_try",
6-
)
7-
agent.train()
3+
if __name__ == "__main__":
    # Hyper-parameters scaled down for 32x32 CIFAR-10 images:
    # 4-pixel patches, small mask scales, and a single-process loader.
    settings = dict(
        dataset="cifar10",
        root_path=r"E:\pytorch-world\cifar",
        download=True,
        folder="results/cifar_jepa",
        write_tag="cifar_jepa",
        batch_size=16,
        pin_mem=False,
        crop_size=32,
        patch_size=4,
        enc_mask_scale=(0.05, 0.15),
        pred_mask_scale=(0.05, 0.15),
        min_keep=1,
        allow_overlap=True,
        num_workers=0,
        epochs=25,
    )
    agent = JEPAAgent(**settings)
    agent.train()

world_models/configs/jepa_config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ def __init__(self):
1919
self.pred_emb_dim: int = 384
2020

2121
# data
22+
self.dataset: str = "imagenet" # "imagenet" or "imagefolder"
23+
self.val_split: float | None = (
24+
None # optional fraction for val split when using imagefolder
25+
)
2226
self.use_gaussian_blur: bool = True
2327
self.use_horizontal_flip: bool = True
2428
self.use_color_distortion: bool = True
@@ -30,6 +34,7 @@ def __init__(self):
3034
self.image_folder: str = "train"
3135
self.crop_size: int = 224
3236
self.crop_scale: Tuple[float, float] = (0.67, 1.0)
37+
self.download: bool = False # allow CIFAR10 download if missing
3338

3439
# mask
3540
self.allow_overlap: bool = False
@@ -68,6 +73,8 @@ def to_dict(self) -> Dict[str, Dict[str, Any]]:
6873
"pred_emb_dim": self.pred_emb_dim,
6974
},
7075
"data": {
76+
"dataset": self.dataset,
77+
"val_split": self.val_split,
7178
"use_gaussian_blur": self.use_gaussian_blur,
7279
"use_horizontal_flip": self.use_horizontal_flip,
7380
"use_color_distortion": self.use_color_distortion,
@@ -79,6 +86,7 @@ def to_dict(self) -> Dict[str, Dict[str, Any]]:
7986
"image_folder": self.image_folder,
8087
"crop_size": self.crop_size,
8188
"crop_scale": self.crop_scale,
89+
"download": self.download, # new
8290
},
8391
"mask": {
8492
"allow_overlap": self.allow_overlap,

world_models/datasets/cifar10.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import torch
2+
from torchvision.datasets import CIFAR10
3+
from logging import getLogger
4+
5+
logger = getLogger()
6+
7+
8+
def make_cifar10(
    transform,
    batch_size,
    collator=None,
    pin_mem=True,
    num_workers=8,
    world_size=1,
    rank=0,
    root_path=None,
    drop_last=True,
    train=True,
    download=False,  # new
):
    """Build a CIFAR-10 dataset and a DistributedSampler-backed DataLoader.

    Args:
        transform: torchvision transform applied to each image.
        batch_size: per-rank batch size.
        collator: optional collate_fn (e.g. a mask collator).
        pin_mem: forwarded to DataLoader as ``pin_memory``.
        num_workers: DataLoader worker processes.
        world_size, rank: sharding parameters for the DistributedSampler
            (explicit values, so no initialized process group is required).
        root_path: directory that holds / will hold the CIFAR-10 files.
        drop_last: drop the final incomplete batch.
        train: load the train split when True, test split when False.
        download: fetch the dataset if it is missing from ``root_path``.

    Returns:
        (dataset, data_loader, dist_sampler)
    """
    cifar = CIFAR10(
        root=root_path,
        train=train,
        download=download,
        transform=transform,
    )
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset=cifar, num_replicas=world_size, rank=rank
    )
    loader = torch.utils.data.DataLoader(
        cifar,
        collate_fn=collator,
        sampler=sampler,
        batch_size=batch_size,
        drop_last=drop_last,
        pin_memory=pin_mem,
        num_workers=num_workers,
        persistent_workers=False,
    )
    logger.info("CIFAR10 data loader created")
    return cifar, loader, sampler

world_models/datasets/imagenet1k.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import torch
1010
import torchvision
11+
from torch.utils.data import random_split
1112

1213
_GLOBAL_SEED = 0
1314
logger = getLogger()
@@ -212,3 +213,41 @@ def copy_imgnt_locally(
212213
logger.info(f"{local_rank}: Checking {tmp_sgnl_file}")
213214

214215
return data_path
216+
217+
218+
def make_imagefolder(
219+
transform,
220+
batch_size,
221+
collator=None,
222+
pin_mem=True,
223+
num_workers=8,
224+
world_size=1,
225+
rank=0,
226+
root_path=None,
227+
image_folder=None,
228+
drop_last=True,
229+
val_split: float | None = None,
230+
):
231+
dataset = torchvision.datasets.ImageFolder(
232+
root=os.path.join(root_path, image_folder) if image_folder else root_path,
233+
transform=transform,
234+
)
235+
if val_split:
236+
val_size = int(len(dataset) * val_split)
237+
train_size = len(dataset) - val_size
238+
dataset, _ = random_split(dataset, [train_size, val_size])
239+
dist_sampler = torch.utils.data.distributed.DistributedSampler(
240+
dataset=dataset, num_replicas=world_size, rank=rank
241+
)
242+
data_loader = torch.utils.data.DataLoader(
243+
dataset,
244+
collate_fn=collator,
245+
sampler=dist_sampler,
246+
batch_size=batch_size,
247+
drop_last=drop_last,
248+
pin_memory=pin_mem,
249+
num_workers=num_workers,
250+
persistent_workers=False,
251+
)
252+
logger.info("ImageFolder data loader created")
253+
return dataset, data_loader, dist_sampler

world_models/training/train_jepa.py

Lines changed: 61 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
AverageMeter,
2828
)
2929
from world_models.utils.jepa_utils import repeat_interleave_batch
30-
from world_models.datasets.imagenet1k import make_imagenet1k
30+
from world_models.datasets.imagenet1k import make_imagenet1k, make_imagefolder
31+
from world_models.datasets.cifar10 import make_cifar10
3132
from world_models.helpers.jepa_helper import load_checkpoint, init_model, init_opt
3233
from world_models.transforms.transforms import make_transforms
3334
from world_models.configs.jepa_config import JEPAConfig
@@ -181,20 +182,52 @@ def main(args, resume_preempt=False):
181182
)
182183

183184
# -- init data-loaders/samplers
184-
_, unsupervised_loader, unsupervised_sampler = make_imagenet1k(
185-
transform=transform,
186-
batch_size=batch_size,
187-
collator=mask_collator,
188-
pin_mem=pin_mem,
189-
training=True,
190-
num_workers=num_workers,
191-
world_size=world_size,
192-
rank=rank,
193-
root_path=root_path,
194-
image_folder=image_folder,
195-
copy_data=copy_data,
196-
drop_last=True,
197-
)
185+
dataset_type = args["data"]["dataset"]
186+
val_split = args["data"]["val_split"]
187+
download = args["data"].get("download", False)
188+
if dataset_type.lower() == "imagenet":
189+
_, unsupervised_loader, unsupervised_sampler = make_imagenet1k(
190+
transform=transform,
191+
batch_size=batch_size,
192+
collator=mask_collator,
193+
pin_mem=pin_mem,
194+
training=True,
195+
num_workers=num_workers,
196+
world_size=world_size,
197+
rank=rank,
198+
root_path=root_path,
199+
image_folder=image_folder,
200+
copy_data=copy_data,
201+
drop_last=True,
202+
)
203+
elif dataset_type.lower() == "cifar10":
204+
_, unsupervised_loader, unsupervised_sampler = make_cifar10(
205+
transform=transform,
206+
batch_size=batch_size,
207+
collator=mask_collator,
208+
pin_mem=pin_mem,
209+
num_workers=num_workers,
210+
world_size=world_size,
211+
rank=rank,
212+
root_path=root_path,
213+
drop_last=True,
214+
train=True,
215+
download=download, # pass through
216+
)
217+
else:
218+
_, unsupervised_loader, unsupervised_sampler = make_imagefolder(
219+
transform=transform,
220+
batch_size=batch_size,
221+
collator=mask_collator,
222+
pin_mem=pin_mem,
223+
num_workers=num_workers,
224+
world_size=world_size,
225+
rank=rank,
226+
root_path=root_path,
227+
image_folder=image_folder,
228+
drop_last=True,
229+
val_split=val_split,
230+
)
198231
ipe = len(unsupervised_loader)
199232

200233
# -- init optimizer and scheduler
@@ -212,9 +245,17 @@ def main(args, resume_preempt=False):
212245
ipe_scale=ipe_scale,
213246
use_bfloat16=use_bfloat16,
214247
)
215-
encoder = DistributedDataParallel(encoder, static_graph=True)
216-
predictor = DistributedDataParallel(predictor, static_graph=True)
217-
target_encoder = DistributedDataParallel(target_encoder)
248+
249+
is_distributed = (
250+
torch.distributed.is_available()
251+
and torch.distributed.is_initialized()
252+
and world_size > 1
253+
)
254+
if is_distributed:
255+
encoder = DistributedDataParallel(encoder, static_graph=True)
256+
predictor = DistributedDataParallel(predictor, static_graph=True)
257+
target_encoder = DistributedDataParallel(target_encoder)
258+
# keep modules unwrapped when not distributed
218259
for p in target_encoder.parameters():
219260
p.requires_grad = False
220261

@@ -328,7 +369,8 @@ def loss_fn(z, h):
328369
else:
329370
loss.backward()
330371
optimizer.step()
331-
grad_stats = grad_logger(encoder.named_parameters())
372+
enc_for_log = encoder.module if is_distributed else encoder
373+
grad_stats = grad_logger(enc_for_log.named_parameters())
332374
optimizer.zero_grad()
333375

334376
# Step 3. momentum update of target encoder

world_models/transforms/transforms.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,5 +53,7 @@ def __call__(self, img):
5353
if torch.bernoulli(torch.tensor(self.prob)) == 0:
5454
return img
5555

56-
radius = self.radius_min + torch.rand(1) * (self.radius_max - self.radius_min)
56+
radius = self.radius_min + torch.rand(1).item() * (
57+
self.radius_max - self.radius_min
58+
)
5759
return img.filter(ImageFilter.GaussianBlur(radius=radius))

world_models/utils/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,6 @@ def max_episode_steps(self):
695695
def apply_masks(x, masks):
    """Gather the embeddings selected by each mask and stack along the batch dim.

    Args:
        x: tensor of shape (B, N, D).
        masks: list of (B, K) index tensors into dimension 1 of ``x``;
            every mask must keep the same number K of positions so the
            results can be concatenated.

    Returns:
        Tensor of shape (len(masks) * B, K, D).
    """
    gathered = [
        torch.gather(x, 1, m.unsqueeze(-1).repeat(1, 1, x.shape[-1]))
        for m in masks
    ]
    return torch.cat(gathered, dim=0)

0 commit comments

Comments
 (0)