Skip to content

Commit f03ddcb

Browse files
Added IJEPA model
1 parent 1f83db2 commit f03ddcb

File tree

14 files changed

+2073
-1
lines changed

14 files changed

+2073
-1
lines changed

jepa_try.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from world_models.models.jepa_agent import JEPAAgent

# Driver script: build a JEPA agent that writes its results and logs under
# results/jepa_try, then launch training immediately.
_agent = JEPAAgent(
    folder="results/jepa_try",
    write_tag="jepa_try",
)
_agent.train()
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import os
2+
from typing import Tuple, Dict, Any
3+
4+
5+
class JEPAConfig:
    """
    Minimal configuration container for JEPA training.

    Holds flat attributes grouped conceptually into the five sections
    (meta / data / mask / optimization / logging) that `train_jepa.main`
    expects, and serializes them into that nested dict via `to_dict`.
    """

    # Attribute names per section, in the exact order the nested dict
    # should present them. `to_dict` is driven entirely by this table.
    _SECTIONS: Dict[str, Tuple[str, ...]] = {
        "meta": (
            "use_bfloat16",
            "model_name",
            "load_checkpoint",
            "read_checkpoint",
            "copy_data",
            "pred_depth",
            "pred_emb_dim",
        ),
        "data": (
            "use_gaussian_blur",
            "use_horizontal_flip",
            "use_color_distortion",
            "color_jitter_strength",
            "batch_size",
            "pin_mem",
            "num_workers",
            "root_path",
            "image_folder",
            "crop_size",
            "crop_scale",
        ),
        "mask": (
            "allow_overlap",
            "patch_size",
            "num_enc_masks",
            "min_keep",
            "enc_mask_scale",
            "num_pred_masks",
            "pred_mask_scale",
            "aspect_ratio",
        ),
        "optimization": (
            "ema",
            "ipe_scale",
            "weight_decay",
            "final_weight_decay",
            "epochs",
            "warmup",
            "start_lr",
            "lr",
            "final_lr",
        ),
        "logging": (
            "folder",
            "write_tag",
        ),
    }

    def __init__(self) -> None:
        # -- meta
        self.use_bfloat16: bool = False
        self.model_name: str = "vit_base"
        self.load_checkpoint: bool = False
        self.read_checkpoint: str | None = None
        self.copy_data: bool = False
        self.pred_depth: int = 6
        self.pred_emb_dim: int = 384

        # -- data
        self.use_gaussian_blur: bool = True
        self.use_horizontal_flip: bool = True
        self.use_color_distortion: bool = True
        self.color_jitter_strength: float = 0.5
        self.batch_size: int = 64
        self.pin_mem: bool = True
        self.num_workers: int = 8
        # Dataset root is overridable through the IMAGENET_ROOT env var.
        self.root_path: str = os.environ.get("IMAGENET_ROOT", "/data/imagenet")
        self.image_folder: str = "train"
        self.crop_size: int = 224
        self.crop_scale: Tuple[float, float] = (0.67, 1.0)

        # -- mask
        self.allow_overlap: bool = False
        self.patch_size: int = 16
        self.num_enc_masks: int = 1
        self.min_keep: int = 4
        self.enc_mask_scale: Tuple[float, float] = (0.15, 0.2)
        self.num_pred_masks: int = 1
        self.pred_mask_scale: Tuple[float, float] = (0.15, 0.2)
        self.aspect_ratio: Tuple[float, float] = (0.75, 1.5)

        # -- optimization
        self.ema: Tuple[float, float] = (0.996, 1.0)
        self.ipe_scale: float = 1.0
        self.weight_decay: float = 0.04
        self.final_weight_decay: float = 0.4
        self.epochs: int = 300
        self.warmup: int = 40
        self.start_lr: float = 1e-6
        self.lr: float = 1.5e-4
        self.final_lr: float = 1e-6

        # -- logging
        self.folder: str = "results/jepa"
        self.write_tag: str = "jepa_run"

    def to_dict(self) -> Dict[str, Dict[str, Any]]:
        """Return the nested ``section -> {key: value}`` dict the trainer expects."""
        return {
            section: {name: getattr(self, name) for name in names}
            for section, names in self._SECTIONS.items()
        }
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
import os
2+
import subprocess
3+
import time
4+
5+
import numpy as np
6+
7+
from logging import getLogger
8+
9+
import torch
10+
import torchvision
11+
12+
_GLOBAL_SEED = 0  # NOTE(review): not referenced in this chunk — presumably consumed by data-loading code elsewhere; confirm
logger = getLogger()  # module-wide logger (root logger; no name argument given)
14+
15+
16+
def make_imagenet1k(
    transform,
    batch_size,
    collator=None,
    pin_mem=True,
    num_workers=8,
    world_size=1,
    rank=0,
    root_path=None,
    image_folder=None,
    training=True,
    copy_data=False,
    drop_last=True,
    subset_file=None,
):
    """
    Build the ImageNet-1k dataset, a distributed sampler over it, and a
    DataLoader wired to both.

    :param transform: per-image transform applied by the dataset
    :param batch_size: samples per batch on this rank
    :param collator: optional custom collate_fn for the loader
    :param pin_mem: pin loader memory for faster host-to-device copies
    :param num_workers: number of loader worker processes
    :param world_size: number of distributed replicas
    :param rank: this process's replica index
    :param root_path: root network directory of the dataset
    :param image_folder: image directory inside root_path
    :param training: load the train split if True, else validation
    :param copy_data: copy the dataset to local storage first
    :param drop_last: drop the trailing incomplete batch
    :param subset_file: optional '.txt' of image IDs to restrict to
    :return: (dataset, data_loader, dist_sampler)
    """
    imagenet = ImageNet(
        root=root_path,
        image_folder=image_folder,
        transform=transform,
        train=training,
        copy_data=copy_data,
        index_targets=False,
    )
    if subset_file is not None:
        # Narrow the dataset down to the IDs listed in subset_file.
        imagenet = ImageNetSubset(imagenet, subset_file)
    logger.info("ImageNet dataset created")

    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset=imagenet, num_replicas=world_size, rank=rank
    )
    loader = torch.utils.data.DataLoader(
        imagenet,
        collate_fn=collator,
        sampler=sampler,
        batch_size=batch_size,
        drop_last=drop_last,
        pin_memory=pin_mem,
        num_workers=num_workers,
        persistent_workers=False,
    )
    logger.info("ImageNet unsupervised data loader created")

    return imagenet, loader, sampler
58+
59+
60+
class ImageNet(torchvision.datasets.ImageFolder):
    """ImageFolder wrapper for ImageNet that can optionally copy the dataset
    to node-local storage first and index samples by target class."""

    def __init__(
        self,
        root,
        image_folder="imagenet_full_size/061417/",
        tar_file="imagenet_full_size-061417.tar.gz",
        transform=None,
        train=True,
        job_id=None,
        local_rank=None,
        copy_data=True,
        index_targets=False,
    ):
        """
        ImageNet

        Dataset wrapper (can copy data locally to machine)

        :param root: root network directory for ImageNet data
        :param image_folder: path to images inside root network directory
        :param tar_file: zipped image_folder inside root network directory
        :param transform: per-image transform passed through to ImageFolder
        :param train: whether to load train data (or validation)
        :param job_id: scheduler job-id used to create dir on local machine
        :param local_rank: rank of this process on the local machine
        :param copy_data: whether to copy data from network file locally
        :param index_targets: whether to index the id of each labeled image
        """
        suffix = "train/" if train else "val/"
        data_path = None
        if copy_data:
            logger.info("copying data locally")
            data_path = copy_imgnt_locally(
                root=root,
                suffix=suffix,
                image_folder=image_folder,
                tar_file=tar_file,
                job_id=job_id,
                local_rank=local_rank,
            )
        # Fall back to reading straight from the network path when the local
        # copy was skipped or could not be set up (copy helper returned None).
        if (not copy_data) or (data_path is None):
            data_path = os.path.join(root, image_folder, suffix)
        logger.info(f"data-path {data_path}")

        super().__init__(root=data_path, transform=transform)
        logger.info("Initialized ImageNet")

        if index_targets:
            # Materialize targets/samples as arrays for vectorized indexing.
            self.targets = np.array([sample[1] for sample in self.samples])
            self.samples = np.array(self.samples)

            mint = None  # running minimum of labeled images per class
            self.target_indices = []
            for t in range(len(self.classes)):
                # reshape(-1) instead of squeeze(): squeeze() collapses a
                # single-match result to a 0-d array whose .tolist() is a
                # bare int, which would break len(indices) below.
                indices = np.argwhere(self.targets == t).reshape(-1).tolist()
                self.target_indices.append(indices)
                mint = len(indices) if mint is None else min(mint, len(indices))
                logger.debug(f"num-labeled target {t} {len(indices)}")
            logger.info(f"min. labeled indices {mint}")
123+
124+
class ImageNetSubset(object):
    """
    View over an ImageNet dataset restricted to the image IDs listed in a
    subset file, exposing the usual dataset protocol.

    :param dataset: ImageNet dataset object
    :param subset_file: '.txt' file containing IDs of IN1K images to keep
    """

    def __init__(self, dataset, subset_file):
        self.dataset = dataset
        self.subset_file = subset_file
        self.filter_dataset_(subset_file)

    def filter_dataset_(self, subset_file):
        """Rebuild self.samples from the (path, target) pairs named in subset_file."""
        root = self.dataset.root
        class_to_idx = self.dataset.class_to_idx
        logger.info(f"Using {subset_file}")
        kept = []
        with open(subset_file, "r") as rfile:
            for line in rfile:
                # Each line looks like "<wnid>_<imgname>": the wnid prefix
                # names the class directory and maps to the integer target.
                class_name = line.split("_")[0]
                img = line.split("\n")[0]
                kept.append((os.path.join(root, class_name, img), class_to_idx[class_name]))
        self.samples = kept

    @property
    def classes(self):
        # Class list is unchanged by subsetting; delegate to the wrapped dataset.
        return self.dataset.classes

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, target = self.samples[index]
        image = self.dataset.loader(path)
        transform = self.dataset.transform
        if transform is not None:
            image = transform(image)
        target_transform = self.dataset.target_transform
        if target_transform is not None:
            target = target_transform(target)
        return image, target
167+
168+
169+
def copy_imgnt_locally(
    root,
    suffix,
    image_folder="imagenet_full_size/061417/",
    tar_file="imagenet_full_size-061417.tar.gz",
    job_id=None,
    local_rank=None,
):
    """
    Copy (untar) the ImageNet archive onto node-local scratch storage.

    Local rank 0 extracts the tarball and then writes a signal file; every
    other local rank polls for that file before proceeding, so all ranks
    return only after the data is in place.

    :param root: network directory containing the tarball
    :param suffix: split sub-directory ("train/" or "val/")
    :param image_folder: extracted image directory inside the archive
    :param tar_file: archive filename inside root
    :param job_id: SLURM job id (read from SLURM_JOBID when None)
    :param local_rank: rank on this node (read from SLURM_LOCALID when None)
    :return: local data path, or None when not running under SLURM
    """
    if job_id is None:
        try:
            job_id = os.environ["SLURM_JOBID"]
        except KeyError:
            logger.info("No job-id, will load directly from network file")
            return None

    if local_rank is None:
        try:
            local_rank = int(os.environ["SLURM_LOCALID"])
        except (KeyError, ValueError):
            # Fixed copy-paste: this branch concerns the local rank, not the job id.
            logger.info("No local-rank, will load directly from network file")
            return None

    source_file = os.path.join(root, tar_file)
    target = f"/scratch/slurm_tmpdir/{job_id}/"
    target_file = os.path.join(target, tar_file)
    data_path = os.path.join(target, image_folder, suffix)
    logger.info(f"{source_file}\n{target}\n{target_file}\n{data_path}")

    tmp_sgnl_file = os.path.join(target, "copy_signal.txt")

    if not os.path.exists(data_path):
        if local_rank == 0:
            commands = [["tar", "-xf", source_file, "-C", target]]
            for cmnd in commands:
                start_time = time.time()
                logger.info(f"Executing {cmnd}")
                # NOTE(review): failures are not checked (no check=True), so a
                # bad extraction still writes the signal file — kept as-is to
                # preserve the original best-effort behavior.
                subprocess.run(cmnd)
                logger.info(f"Cmnd took {(time.time()-start_time)/60.} min.")
            # Signal the other local ranks that extraction is complete.
            with open(tmp_sgnl_file, "w") as f:
                print("Done copying locally.", file=f)
        else:
            # Non-zero local ranks wait for rank 0's signal file.
            while not os.path.exists(tmp_sgnl_file):
                time.sleep(60)
                logger.info(f"{local_rank}: Checking {tmp_sgnl_file}")

    return data_path

0 commit comments

Comments
 (0)