@@ -356,7 +356,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return (emb / norm).type_as(hidden_states)
 
 
-class CosmosTransformer(ModelMixin, ConfigMixin):
+class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
     r"""
     A Transformer model for video-like data used in [Cosmos](https://github.com/NVIDIA/Cosmos).
 
@@ -423,9 +423,9 @@ def __init__(
             hidden_size=attention_head_dim, max_size=max_size, patch_size=patch_size, rope_scale=rope_scale
         )
 
-        self.learnable_pos_embedder = None
+        self.learnable_pos_embed = None
         if extra_pos_embed_type == "learnable":
-            self.learnable_pos_embedder = CosmosLearnablePositionalEmbed(
+            self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
                 hidden_size=hidden_size,
                 max_size=max_size,
                 patch_size=patch_size,
@@ -477,7 +477,7 @@ def forward(
 
         # 2. Generate positional embeddings
         image_rotary_emb = self.rope(hidden_states, fps=fps)
-        extra_pos_emb = self.learnable_pos_embedder(hidden_states) if self.config.extra_pos_embed_type else None
+        extra_pos_emb = self.learnable_pos_embed(hidden_states) if self.config.extra_pos_embed_type else None
 
         # 3. Patchify input
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
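
For downstream callers this commit is a pure rename: the public class `CosmosTransformer` becomes `CosmosTransformer3DModel`, and its `learnable_pos_embedder` attribute becomes `learnable_pos_embed`. The sketch below shows how user code would track the rename; `get_extra_pos_embed` is a hypothetical helper, and the import assumes a `diffusers` build that already ships the renamed class.

```python
# Minimal sketch of the rename's downstream impact. Only the class and attribute
# names come from this diff; the helper itself is an assumption for illustration.
from diffusers import CosmosTransformer3DModel  # previously `CosmosTransformer`

def get_extra_pos_embed(model: CosmosTransformer3DModel, hidden_states):
    # The attribute is now `learnable_pos_embed` (formerly `learnable_pos_embedder`);
    # it is only set when the model config uses extra_pos_embed_type="learnable".
    if model.learnable_pos_embed is None:
        return None
    return model.learnable_pos_embed(hidden_states)
```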