@@ -90,36 +90,29 @@ def __init__(
         self.modality = Modalities.get_modality(modality)
         self.mask_generator = IJEPAMaskGenerator()
 
-        self.current_step = 0
-        self.total_steps = None
-
         self.encoder = encoder
         self.predictor = predictor
 
         self.predictor.num_patches = encoder.patch_embed.num_patches
         self.predictor.embed_dim = encoder.embed_dim
         self.predictor.num_heads = encoder.num_heads
 
-        self.ema = ExponentialMovingAverage(
-            self.encoder,
-            ema_decay,
-            ema_decay_end,
-            ema_anneal_end_step,
-            device_id=self.device,
+        self.target_encoder = ExponentialMovingAverage(
+            self.encoder, ema_decay, ema_decay_end, ema_anneal_end_step
         )
 
     def configure_model(self) -> None:
         """Configure the model."""
-        self.ema.model.to(device=self.device, dtype=self.dtype)
+        self.target_encoder.configure_model(self.device)
 
     def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
         """Perform exponential moving average update of target encoder.
 
-        This is done right after the optimizer step, which comes just before `zero_grad`
-        to account for gradient accumulation.
+        This is done right after ``optimizer.step()``, which comes just before
+        ``optimizer.zero_grad()`` to account for gradient accumulation.
         """
-        if self.ema is not None:
-            self.ema.step(self.encoder)
+        if self.target_encoder is not None:
+            self.target_encoder.step(self.encoder)
 
     def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
         """Perform a single training step.
@@ -200,10 +193,10 @@ def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
         checkpoint : dict[str, Any]
             The state dictionary to save the EMA state to.
         """
-        if self.ema is not None:
+        if self.target_encoder is not None:
             checkpoint["ema_params"] = {
-                "decay": self.ema.decay,
-                "num_updates": self.ema.num_updates,
+                "decay": self.target_encoder.decay,
+                "num_updates": self.target_encoder.num_updates,
             }
 
     def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
@@ -214,12 +207,12 @@ def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
         checkpoint : dict[str, Any]
             The state dictionary to restore the EMA state from.
         """
-        if "ema_params" in checkpoint and self.ema is not None:
+        if "ema_params" in checkpoint and self.target_encoder is not None:
             ema_params = checkpoint.pop("ema_params")
-            self.ema.decay = ema_params["decay"]
-            self.ema.num_updates = ema_params["num_updates"]
+            self.target_encoder.decay = ema_params["decay"]
+            self.target_encoder.num_updates = ema_params["num_updates"]
 
-            self.ema.restore(self.encoder)
+            self.target_encoder.restore(self.encoder)
 
     def _shared_step(
         self, batch: dict[str, Any], batch_idx: int, step_type: str
@@ -237,7 +230,7 @@ def _shared_step(
 
         # Forward pass through target encoder to get h
         with torch.no_grad():
-            h = self.ema.model(batch)[0]
+            h = self.target_encoder.model(batch)[0]
             h = F.layer_norm(h, h.size()[-1:])
             h_masked = apply_masks(h, predictor_masks)
             h_masked = repeat_interleave_batch(
@@ -252,7 +245,7 @@ def _shared_step(
         z_pred = self.predictor(z, encoder_masks, predictor_masks)
 
         if step_type == "train":
-            self.log("train/ema_decay", self.ema.decay, prog_bar=True)
+            self.log("train/ema_decay", self.target_encoder.decay, prog_bar=True)
 
         if self.loss_fn is not None and (
             step_type == "train"
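
For readers unfamiliar with the EMA wrapper renamed above, here is a minimal sketch of the update that `target_encoder.step(self.encoder)` performs. The class name, the linear annealing of the decay from `ema_decay` to `ema_decay_end` over `ema_anneal_end_step` updates, and the omission of `configure_model`/`restore` are assumptions for illustration; the repository's actual `ExponentialMovingAverage` may differ.

```python
import copy

import torch


class EMASketch:
    """Hypothetical EMA wrapper; not the repository's ``ExponentialMovingAverage``."""

    def __init__(self, model, decay_start, decay_end, anneal_end_step):
        self.model = copy.deepcopy(model)  # frozen copy used as the target encoder
        self.model.requires_grad_(False)
        self.decay_start = decay_start
        self.decay_end = decay_end
        self.anneal_end_step = anneal_end_step
        self.decay = decay_start
        self.num_updates = 0

    def _annealed_decay(self) -> float:
        # Assumed schedule: linearly anneal the decay from decay_start to decay_end.
        t = min(self.num_updates / max(self.anneal_end_step, 1), 1.0)
        return self.decay_start + t * (self.decay_end - self.decay_start)

    @torch.no_grad()
    def step(self, online_model) -> None:
        # target <- decay * target + (1 - decay) * online
        self.decay = self._annealed_decay()
        for p_t, p_o in zip(self.model.parameters(), online_model.parameters()):
            p_t.mul_(self.decay).add_(p_o, alpha=1.0 - self.decay)
        self.num_updates += 1
```

Calling `step` once per optimizer step (as in `on_before_zero_grad` above) keeps the target encoder a slowly moving average of the online encoder, which is what the predictor regresses against in `_shared_step`.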