Skip to content

Commit 03d2587

Browse files
Merge pull request #8 from ParamThakkar123/add_jepa
Added IJEPA model
2 parents 266aa41 + 0ac3cac commit 03d2587

File tree

16 files changed

+2220
-1
lines changed

16 files changed

+2220
-1
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ results/
44
dist/*
55
pytorch_world.egg-info/
66
world_models/models/data/
7+
cifar

jepa_try.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from world_models.models.jepa_agent import JEPAAgent


if __name__ == "__main__":
    # Train I-JEPA on CIFAR-10. The mask/patch settings are scaled down from
    # the ImageNet defaults to suit 32x32 images (small patch grid).
    agent = JEPAAgent(
        dataset="cifar10",
        root_path=r"E:\pytorch-world\cifar",  # local Windows path — adjust per machine
        download=True,  # fetch CIFAR-10 into root_path if missing
        folder="results/cifar_jepa",  # output directory for checkpoints/logs
        write_tag="cifar_jepa",
        batch_size=16,
        pin_mem=False,
        crop_size=32,  # CIFAR-10 native resolution
        patch_size=4,  # 32/4 -> 8x8 patch grid
        enc_mask_scale=(0.05, 0.15),
        pred_mask_scale=(0.05, 0.15),
        min_keep=1,  # tiny grid: permit masks that keep a single patch
        allow_overlap=True,
        num_workers=0,  # single-process loading (Windows-friendly)
        epochs=25,
    )
    agent.train()
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import os
2+
from typing import Tuple, Dict, Any
3+
4+
5+
class JEPAConfig:
    """
    Lightweight hyper-parameter holder for a JEPA training run.

    Defaults follow the ImageNet I-JEPA recipe. Callers mutate attributes
    directly, then call :meth:`to_dict` to obtain the nested mapping that
    ``train_jepa.main`` consumes.
    """

    def __init__(self):
        # -- meta --------------------------------------------------------
        self.use_bfloat16: bool = False
        self.model_name: str = "vit_base"
        self.load_checkpoint: bool = False
        self.read_checkpoint: str | None = None
        self.copy_data: bool = False
        self.pred_depth: int = 6
        self.pred_emb_dim: int = 384

        # -- data --------------------------------------------------------
        self.dataset: str = "imagenet"  # "imagenet" or "imagefolder"
        # Optional validation fraction, only used with "imagefolder".
        self.val_split: float | None = None
        self.use_gaussian_blur: bool = True
        self.use_horizontal_flip: bool = True
        self.use_color_distortion: bool = True
        self.color_jitter_strength: float = 0.5
        self.batch_size: int = 64
        self.pin_mem: bool = True
        self.num_workers: int = 8
        # Dataset root; overridable via the IMAGENET_ROOT env var.
        self.root_path: str = os.environ.get("IMAGENET_ROOT", "/data/imagenet")
        self.image_folder: str = "train"
        self.crop_size: int = 224
        self.crop_scale: Tuple[float, float] = (0.67, 1.0)
        self.download: bool = False  # allow CIFAR10 download if missing

        # -- mask --------------------------------------------------------
        self.allow_overlap: bool = False
        self.patch_size: int = 16
        self.num_enc_masks: int = 1
        self.min_keep: int = 4
        self.enc_mask_scale: Tuple[float, float] = (0.15, 0.2)
        self.num_pred_masks: int = 1
        self.pred_mask_scale: Tuple[float, float] = (0.15, 0.2)
        self.aspect_ratio: Tuple[float, float] = (0.75, 1.5)

        # -- optimization ------------------------------------------------
        self.ema: Tuple[float, float] = (0.996, 1.0)
        self.ipe_scale: float = 1.0
        self.weight_decay: float = 0.04
        self.final_weight_decay: float = 0.4
        self.epochs: int = 300
        self.warmup: int = 40
        self.start_lr: float = 1e-6
        self.lr: float = 1.5e-4
        self.final_lr: float = 1e-6

        # -- logging -----------------------------------------------------
        self.folder: str = "results/jepa"
        self.write_tag: str = "jepa_run"

    def to_dict(self) -> Dict[str, Dict[str, Any]]:
        """Assemble the nested section->options mapping for ``train_jepa.main``."""
        meta = {
            "use_bfloat16": self.use_bfloat16,
            "model_name": self.model_name,
            "load_checkpoint": self.load_checkpoint,
            "read_checkpoint": self.read_checkpoint,
            "copy_data": self.copy_data,
            "pred_depth": self.pred_depth,
            "pred_emb_dim": self.pred_emb_dim,
        }
        data = {
            "dataset": self.dataset,
            "val_split": self.val_split,
            "use_gaussian_blur": self.use_gaussian_blur,
            "use_horizontal_flip": self.use_horizontal_flip,
            "use_color_distortion": self.use_color_distortion,
            "color_jitter_strength": self.color_jitter_strength,
            "batch_size": self.batch_size,
            "pin_mem": self.pin_mem,
            "num_workers": self.num_workers,
            "root_path": self.root_path,
            "image_folder": self.image_folder,
            "crop_size": self.crop_size,
            "crop_scale": self.crop_scale,
            "download": self.download,  # new
        }
        mask = {
            "allow_overlap": self.allow_overlap,
            "patch_size": self.patch_size,
            "num_enc_masks": self.num_enc_masks,
            "min_keep": self.min_keep,
            "enc_mask_scale": self.enc_mask_scale,
            "num_pred_masks": self.num_pred_masks,
            "pred_mask_scale": self.pred_mask_scale,
            "aspect_ratio": self.aspect_ratio,
        }
        optimization = {
            "ema": self.ema,
            "ipe_scale": self.ipe_scale,
            "weight_decay": self.weight_decay,
            "final_weight_decay": self.final_weight_decay,
            "epochs": self.epochs,
            "warmup": self.warmup,
            "start_lr": self.start_lr,
            "lr": self.lr,
            "final_lr": self.final_lr,
        }
        logging_section = {
            "folder": self.folder,
            "write_tag": self.write_tag,
        }
        return {
            "meta": meta,
            "data": data,
            "mask": mask,
            "optimization": optimization,
            "logging": logging_section,
        }

world_models/datasets/cifar10.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import torch
2+
from torchvision.datasets import CIFAR10
3+
from logging import getLogger
4+
5+
logger = getLogger()
6+
7+
8+
def make_cifar10(
    transform,
    batch_size,
    collator=None,
    pin_mem=True,
    num_workers=8,
    world_size=1,
    rank=0,
    root_path=None,
    drop_last=True,
    train=True,
    download=False,  # new
):
    """
    Build the CIFAR-10 dataset with a distributed-aware DataLoader.

    The sampler partitions the data across ``world_size`` replicas, with
    ``rank`` selecting this process's shard. Set ``download=True`` to fetch
    the archive into ``root_path`` when it is not already present.

    Returns a ``(dataset, data_loader, dist_sampler)`` triple.
    """
    dataset = CIFAR10(root=root_path, train=train, download=download, transform=transform)
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset=dataset, num_replicas=world_size, rank=rank
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        collate_fn=collator,
        drop_last=drop_last,
        pin_memory=pin_mem,
        num_workers=num_workers,
        persistent_workers=False,
    )
    logger.info("CIFAR10 data loader created")
    return dataset, loader, sampler

0 commit comments

Comments
 (0)