
Commit 669028a

Refactor iJEPA some modules (#54)
1 parent 287493f commit 669028a

7 files changed: +479 additions, -459 deletions


mmlearn/datasets/processors/masking.py

Lines changed: 1 addition & 1 deletion
@@ -321,7 +321,7 @@ class IJEPAMaskGenerator:
     allow_overlap: bool = False
     enc_mask_scale: tuple[float, float] = (0.85, 1.0)
     pred_mask_scale: tuple[float, float] = (0.15, 0.2)
-    aspect_ratio: tuple[float, float] = (0.75, 1.0)
+    aspect_ratio: tuple[float, float] = (0.75, 1.5)
     nenc: int = 1
     npred: int = 4
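
The only change widens the default predictor-mask aspect-ratio range from (0.75, 1.0) to (0.75, 1.5), so sampled blocks can now be elongated in either direction instead of only one. As a rough illustration on a 224x224 image with 16x16 patches (a 14x14 grid), here is a standalone sketch assuming the usual I-JEPA block-sampling rule (block height and width derived from a target area and a sampled aspect ratio); it mirrors the defaults above but is not the mmlearn implementation:

```python
import math
import random

# Standalone sketch (not the mmlearn code): how a sampled scale and aspect
# ratio translate into a block of patches, assuming the rule
# h = round(sqrt(area * ratio)), w = round(sqrt(area / ratio)).
grid = 14                              # 224 / 16 patch grid (assumption)
scale = random.uniform(0.15, 0.2)      # pred_mask_scale default
ratio = random.uniform(0.75, 1.5)      # new aspect_ratio default
area = scale * grid * grid             # target number of patches in the block
h = int(round(math.sqrt(area * ratio)))
w = int(round(math.sqrt(area / ratio)))
print(f"predictor block: {h} x {w} patches")  # e.g. 5 x 6 or 8 x 5
```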

mmlearn/modules/ema.py

Lines changed: 18 additions & 8 deletions
@@ -23,8 +23,6 @@ class ExponentialMovingAverage:
         The final decay value for EMA.
     ema_anneal_end_step : int
         The number of steps to anneal the decay from ``ema_decay`` to ``ema_end_decay``.
-    device_id : Optional[Union[int, torch.device]], optional, default=None
-        The device to move the model to.
     skip_keys : Optional[Union[list[str], Set[str]]], optional, default=None
         The keys to skip in the EMA update. These parameters will be copied directly
         from the model to the EMA model.
@@ -41,14 +39,9 @@ def __init__(
         ema_decay: float,
         ema_end_decay: float,
         ema_anneal_end_step: int,
-        device_id: Optional[Union[int, torch.device]] = None,
         skip_keys: Optional[Union[list[str], Set[str]]] = None,
-    ):
+    ) -> None:
         self.model = self.deepcopy_model(model)
-        self.model.requires_grad_(False)
-
-        if device_id is not None:
-            self.model.to(device_id)

         self.skip_keys: Union[list[str], set[str]] = skip_keys or set()
         self.num_updates = 0
@@ -57,6 +50,8 @@ def __init__(
         self.ema_end_decay = ema_end_decay
         self.ema_anneal_end_step = ema_anneal_end_step

+        self._model_configured = False
+
     @staticmethod
     def deepcopy_model(model: torch.nn.Module) -> torch.nn.Module:
         """Deep copy the model.
@@ -93,8 +88,23 @@ def get_annealed_rate(
         pct_remaining = 1 - curr_step / total_steps
         return end - r * pct_remaining

+    def configure_model(self, device_id: Union[int, torch.device]) -> None:
+        """Configure the model for EMA."""
+        if self._model_configured:
+            return
+
+        self.model.requires_grad_(False)
+        self.model.to(device_id)
+
+        self._model_configured = True
+
     def step(self, new_model: torch.nn.Module) -> None:
         """Perform single EMA update step."""
+        if not self._model_configured:
+            raise RuntimeError(
+                "Model is not configured for EMA. Call `configure_model` first."
+            )
+
         self._update_weights(new_model)
         self._update_ema_decay()
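
Taken together, these changes split construction from device placement: freezing and moving the shadow model now happen in an explicit `configure_model` call, and `step` refuses to run on an unconfigured model. A minimal usage sketch of the new API (the `nn.Linear` is only a stand-in module for illustration):

```python
import torch

from mmlearn.modules.ema import ExponentialMovingAverage

# Stand-in module for illustration; any torch.nn.Module works.
model = torch.nn.Linear(8, 8)

ema = ExponentialMovingAverage(
    model,
    ema_decay=0.996,
    ema_end_decay=0.9999,
    ema_anneal_end_step=300_000,
)

# Freezing and device placement now happen in this explicit step;
# calling `step` first would raise the RuntimeError added above.
ema.configure_model(torch.device("cpu"))

ema.step(model)  # one EMA update of the shadow weights
```

The guard makes the failure mode explicit: skipping `configure_model` now surfaces as a `RuntimeError` rather than silently updating an unfrozen copy on the wrong device.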

mmlearn/tasks/ijepa.py

Lines changed: 16 additions & 23 deletions
@@ -90,36 +90,29 @@ def __init__(
         self.modality = Modalities.get_modality(modality)
         self.mask_generator = IJEPAMaskGenerator()

-        self.current_step = 0
-        self.total_steps = None
-
         self.encoder = encoder
         self.predictor = predictor

         self.predictor.num_patches = encoder.patch_embed.num_patches
         self.predictor.embed_dim = encoder.embed_dim
         self.predictor.num_heads = encoder.num_heads

-        self.ema = ExponentialMovingAverage(
-            self.encoder,
-            ema_decay,
-            ema_decay_end,
-            ema_anneal_end_step,
-            device_id=self.device,
+        self.target_encoder = ExponentialMovingAverage(
+            self.encoder, ema_decay, ema_decay_end, ema_anneal_end_step
         )

     def configure_model(self) -> None:
         """Configure the model."""
-        self.ema.model.to(device=self.device, dtype=self.dtype)
+        self.target_encoder.configure_model(self.device)

     def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
         """Perform exponential moving average update of target encoder.

-        This is done right after the optimizer step, which comes just before `zero_grad`
-        to account for gradient accumulation.
+        This is done right after the ``optimizer.step()``, which comes just before
+        ``optimizer.zero_grad()`` to account for gradient accumulation.
         """
-        if self.ema is not None:
-            self.ema.step(self.encoder)
+        if self.target_encoder is not None:
+            self.target_encoder.step(self.encoder)

     def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
         """Perform a single training step.
@@ -200,10 +193,10 @@ def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
         checkpoint : dict[str, Any]
             The state dictionary to save the EMA state to.
         """
-        if self.ema is not None:
+        if self.target_encoder is not None:
             checkpoint["ema_params"] = {
-                "decay": self.ema.decay,
-                "num_updates": self.ema.num_updates,
+                "decay": self.target_encoder.decay,
+                "num_updates": self.target_encoder.num_updates,
             }

     def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
@@ -214,12 +207,12 @@ def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
         checkpoint : dict[str, Any]
             The state dictionary to restore the EMA state from.
         """
-        if "ema_params" in checkpoint and self.ema is not None:
+        if "ema_params" in checkpoint and self.target_encoder is not None:
             ema_params = checkpoint.pop("ema_params")
-            self.ema.decay = ema_params["decay"]
-            self.ema.num_updates = ema_params["num_updates"]
+            self.target_encoder.decay = ema_params["decay"]
+            self.target_encoder.num_updates = ema_params["num_updates"]

-            self.ema.restore(self.encoder)
+            self.target_encoder.restore(self.encoder)

     def _shared_step(
         self, batch: dict[str, Any], batch_idx: int, step_type: str
@@ -237,7 +230,7 @@ def _shared_step(

         # Forward pass through target encoder to get h
         with torch.no_grad():
-            h = self.ema.model(batch)[0]
+            h = self.target_encoder.model(batch)[0]
             h = F.layer_norm(h, h.size()[-1:])
             h_masked = apply_masks(h, predictor_masks)
             h_masked = repeat_interleave_batch(
@@ -252,7 +245,7 @@
         z_pred = self.predictor(z, encoder_masks, predictor_masks)

         if step_type == "train":
-            self.log("train/ema_decay", self.ema.decay, prog_bar=True)
+            self.log("train/ema_decay", self.target_encoder.decay, prog_bar=True)

         if self.loss_fn is not None and (
             step_type == "train"
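
For context on `on_before_zero_grad`: the EMA update is intended to run once per optimizer step rather than once per micro-batch, which is what the docstring means by accounting for gradient accumulation. A hypothetical plain-PyTorch loop (not Lightning internals; the model, optimizer, and batches are stand-ins) showing that ordering:

```python
import torch

# Hypothetical stand-ins, only to make the hook ordering concrete.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [torch.randn(2, 4) for _ in range(8)]
accumulate_grad_batches = 4  # illustrative value

for i, batch in enumerate(batches):
    loss = model(batch).mean() / accumulate_grad_batches
    loss.backward()

    if (i + 1) % accumulate_grad_batches == 0:
        optimizer.step()
        # Lightning invokes `on_before_zero_grad` here, after `optimizer.step()`
        # and before `optimizer.zero_grad()`, so the hook body runs
        # `self.target_encoder.step(self.encoder)` once per optimizer step,
        # not once per accumulated micro-batch.
        optimizer.zero_grad()
```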

projects/ijepa/configs/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -15,7 +15,7 @@
 def ijepa_transforms(
     crop_size: int = 224,
     crop_scale: tuple = (0.3, 1.0),
-    color_jitter: float = 0.0,
+    color_jitter_strength: float = 0.0,
     horizontal_flip: bool = False,
     color_distortion: bool = False,
     gaussian_blur: bool = False,
@@ -31,7 +31,7 @@ def ijepa_transforms(
         Size of the image crop.
     crop_scale : tuple, default=(0.3, 1.0)
         Range for the random resized crop scaling.
-    color_jitter : float, default=0.0
+    color_jitter_strength : float, default=0.0
         Strength of color jitter.
     horizontal_flip : bool, default=False
         Whether to apply random horizontal flip.
@@ -89,7 +89,7 @@ def __call__(self, img):
     if horizontal_flip:
         transforms_list.append(transforms.RandomHorizontalFlip())
     if color_distortion:
-        transforms_list.append(get_color_distortion(s=color_jitter))
+        transforms_list.append(get_color_distortion(s=color_jitter_strength))
     if gaussian_blur:
         transforms_list.append(GaussianBlur(p=0.5))
     else:
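
The rename from `color_jitter` to `color_jitter_strength` is purely cosmetic; callers only need to pass the new keyword. A hypothetical call mirroring the values the experiment config below sets for the training transform (the import path and return type are assumptions, not shown in this diff):

```python
# Hypothetical usage; import path is an assumption based on the file location.
from projects.ijepa.configs import ijepa_transforms

train_transform = ijepa_transforms(
    crop_size=224,
    crop_scale=(0.3, 1.0),
    color_jitter_strength=0.4,  # formerly `color_jitter`
    horizontal_flip=True,
    color_distortion=True,
    gaussian_blur=False,
)
```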

projects/ijepa/configs/experiment/reproduce_imagenet.yaml renamed to projects/ijepa/configs/experiment/in1k_vit_small.yaml

Lines changed: 37 additions & 23 deletions
@@ -5,13 +5,12 @@ defaults:
   - /datasets/[email protected]: ijepa_transforms
   - /[email protected]: ImageNet
   - /datasets/[email protected]: ijepa_transforms
-  - /modules/[email protected]: vit_base
+  - /modules/[email protected]: vit_small
   - /modules/[email protected]: vit_predictor
   - /modules/[email protected]: AdamW
-  - /modules/[email protected]_scheduler.scheduler: CosineAnnealingLR
+  - /modules/[email protected]_scheduler.scheduler: linear_warmup_cosine_annealing_lr
   - /trainer/[email protected]_monitor: LearningRateMonitor
   - /trainer/[email protected]_checkpoint: ModelCheckpoint
-  - /trainer/[email protected]_stopping: EarlyStopping
   - /trainer/[email protected]_summary: ModelSummary
   - /trainer/[email protected]: WandbLogger
   - override /task: IJEPA
@@ -20,6 +19,16 @@ defaults:
 seed: 0

 datasets:
+  train:
+    transform:
+      color_jitter_strength: 0.4
+      horizontal_flip: true
+      color_distortion: true
+      gaussian_blur: false
+      crop_scale:
+        - 0.3
+        - 1.0
+      crop_size: 224
   val:
     split: val
     transform:
@@ -28,45 +37,50 @@ datasets:
 dataloader:
   train:
     batch_size: 256
-    num_workers: 10
+    num_workers: 8
+    pin_memory: true
+    drop_last: true
   val:
     batch_size: 256
-    num_workers: 10
+    num_workers: 8
+    pin_memory: false

 task:
+  ema_decay: 0.996
+  ema_decay_end: 1.0
+  ema_anneal_end_step: ${task.lr_scheduler.scheduler.max_steps}
+  predictor:
+    kwargs:
+      embed_dim: 384
+      predictor_embed_dim: 384
+      depth: 6
+      num_heads: 6
   optimizer:
-    betas:
-      - 0.9
-      - 0.999
     lr: 1.0e-3
     weight_decay: 0.05
-    eps: 1.0e-8
   lr_scheduler:
     scheduler:
-      T_max: ${trainer.max_epochs}
+      warmup_steps: 12_510
+      max_steps: 125_100
+      start_factor: 0.2
+      eta_min: 1.0e-6
     extras:
-      interval: epoch
+      interval: step

 trainer:
-  max_epochs: 300
-  precision: 16-mixed
+  max_epochs: 100
+  precision: bf16-mixed
   deterministic: False
   benchmark: True
   sync_batchnorm: False # Set to True if using DDP with batchnorm
-  log_every_n_steps: 100
-  accumulate_grad_batches: 4
+  log_every_n_steps: 10
+  accumulate_grad_batches: 1
   check_val_every_n_epoch: 1
   callbacks:
     model_checkpoint:
-      monitor: val/loss
-      save_top_k: 1
       save_last: True
-      every_n_epochs: 1
-      dirpath: /checkpoint/${oc.env:USER}/${oc.env:SLURM_JOB_ID} # only works on Vector SLURM environment
-    early_stopping:
-      monitor: val/loss
-      patience: 5
-      mode: min
+      every_n_epochs: 10
+      dirpath: /checkpoint/${oc.env:USER}/${oc.env:SLURM_JOB_ID} # only works on VI's SLURM environment
     model_summary:
       max_depth: 2
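
The scheduler swap (epoch-based `CosineAnnealingLR` to a step-based linear warmup plus cosine annealing) is the main behavioral change in this config. As a shape check only, here is a standalone sketch of a schedule with these hyperparameters, assuming the common semantics of a linear warmup from `start_factor * lr` to `lr` followed by a cosine decay to `eta_min`; it is not the scheduler implementation the config actually references:

```python
import math

# Standalone sketch of the assumed schedule shape; not the library scheduler.
BASE_LR, WARMUP, MAX_STEPS = 1.0e-3, 12_510, 125_100
START_FACTOR, ETA_MIN = 0.2, 1.0e-6

def lr_at(step: int) -> float:
    if step < WARMUP:
        # linear warmup from START_FACTOR * BASE_LR to BASE_LR
        return BASE_LR * (START_FACTOR + (1.0 - START_FACTOR) * step / WARMUP)
    # cosine decay from BASE_LR to ETA_MIN over the remaining steps
    progress = (step - WARMUP) / (MAX_STEPS - WARMUP)
    return ETA_MIN + 0.5 * (BASE_LR - ETA_MIN) * (1.0 + math.cos(math.pi * progress))

print(lr_at(0), lr_at(12_510), lr_at(125_100))  # 2e-4, 1e-3, 1e-6
```

Because `ema_anneal_end_step` interpolates `${task.lr_scheduler.scheduler.max_steps}`, the EMA decay anneal from 0.996 to 1.0 spans the same 125,100 steps as the learning-rate schedule.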

tests/modules/test_ema.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ def test_ema() -> None:
         ema_end_decay=0.9999,
         ema_anneal_end_step=300000,
     )
-    ema.model = ema.model.cpu()  # for testing purposes
+    ema.configure_model(device_id=torch.device("cpu"))

     # test output between model and ema model
     model_input = torch.rand(1, 3, 224, 224)
