@@ -260,7 +260,6 @@ def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
         return torch.cat(result, dim=-1).to(device)
 
     def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
-        # Get batch info and dimensions
         batch_size, channels, height, width = hidden_states.shape
         p = self.patch_size
         post_patch_height, post_patch_width = height // p, width // p
@@ -276,11 +275,11 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
         position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
 
         for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
-            # Set caption positions
+            # add caption position ids
             position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
             position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len
 
-            # Set image patch positions
+            # add image position ids
             row_ids = (
                 torch.arange(post_patch_height, dtype=torch.int32, device=device)
                 .view(-1, 1)
@@ -296,10 +295,10 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
             position_ids[i, cap_seq_len:seq_len, 1] = row_ids
             position_ids[i, cap_seq_len:seq_len, 2] = col_ids
 
-        # Get frequencies
+        # Get combined rotary embeddings
         freqs_cis = self._get_freqs_cis(position_ids)
 
-        # Split frequencies for captions and images
+        # create separate rotary embeddings for captions and images
         cap_freqs_cis = torch.zeros(
             batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
         )
@@ -311,7 +310,7 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
             cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
             img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len]
 
-        # patch embeddings
+        # image patch embeddings
         hidden_states = (
             hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
             .permute(0, 2, 4, 3, 5, 1)
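
For reference, here is a minimal standalone sketch (not part of the diff) of the 3D position-id layout the renamed comments describe, for a single caption + image pair: axis 0 counts caption tokens up to `cap_seq_len` and parks every image token at the next index, while axes 1 and 2 carry per-patch row/column coordinates. The `.repeat(...).flatten()` steps after `.view(-1, 1)` are assumed continuations, since the second hunk is truncated at that point.

```python
# Standalone sketch: 3D position ids for one caption + image pair.
# Sizes are illustrative, not the module's real hyperparameters.
import torch

cap_seq_len = 3          # caption tokens
post_patch_height = 2    # patch rows
post_patch_width = 2     # patch columns
seq_len = cap_seq_len + post_patch_height * post_patch_width

position_ids = torch.zeros(seq_len, 3, dtype=torch.int32)

# axis 0: caption tokens count 0..cap_seq_len-1; all image tokens
# share the next index, so the whole image sits at one "step"
position_ids[:cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32)
position_ids[cap_seq_len:seq_len, 0] = cap_seq_len

# axes 1 and 2: per-patch row/column coordinates (assumed continuation)
row_ids = (
    torch.arange(post_patch_height, dtype=torch.int32)
    .view(-1, 1)
    .repeat(1, post_patch_width)
    .flatten()
)
col_ids = (
    torch.arange(post_patch_width, dtype=torch.int32)
    .view(1, -1)
    .repeat(post_patch_height, 1)
    .flatten()
)
position_ids[cap_seq_len:seq_len, 1] = row_ids
position_ids[cap_seq_len:seq_len, 2] = col_ids

print(position_ids)
# tensor([[0, 0, 0],
#         [1, 0, 0],
#         [2, 0, 0],
#         [3, 0, 0],
#         [3, 0, 1],
#         [3, 1, 0],
#         [3, 1, 1]], dtype=torch.int32)
```

Slicing `freqs_cis = self._get_freqs_cis(position_ids)` per sample at `cap_seq_len` then yields the separate caption and image rotary embeddings that the updated comments name.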
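A similar sketch for the patchify rearrangement under the `# image patch embeddings` comment. The diff cuts off after `.permute(0, 2, 4, 3, 5, 1)`; the trailing `.flatten(3).flatten(1, 2)` is an assumption about what follows, turning `(B, C, H, W)` into a `(B, num_patches, p * p * C)` token sequence.

```python
# Standalone sketch: view/permute patchify on a dummy image tensor.
import torch

batch_size, channels, height, width = 1, 4, 8, 8
p = 2  # patch_size
post_patch_height, post_patch_width = height // p, width // p

hidden_states = torch.randn(batch_size, channels, height, width)

# (B, C, H, W) -> (B, C, H/p, p, W/p, p) -> (B, H/p, W/p, p, p, C),
# then flatten each patch, then flatten the patch grid into a sequence
patches = (
    hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
    .permute(0, 2, 4, 3, 5, 1)
    .flatten(3)     # assumed: (B, H/p, W/p, p*p*C)
    .flatten(1, 2)  # assumed: (B, H/p * W/p, p*p*C)
)
print(patches.shape)  # torch.Size([1, 16, 16])
```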