@@ -115,8 +115,16 @@ def __init__(
         self.theta = theta
         self._causal_rope_fix = _causal_rope_fix
 
-
-    def _prepare_video_coords(self, batch_size: int, num_frames: int, height: int, width: int, rope_interpolation_scale: Tuple[torch.Tensor, float, float], device: torch.device) -> torch.Tensor:
+    def _prepare_video_coords(
+        self,
+        batch_size: int,
+        num_frames: int,
+        height: int,
+        width: int,
+        rope_interpolation_scale: Tuple[torch.Tensor, float, float],
+        frame_rate: float,
+        device: torch.device,
+    ) -> torch.Tensor:
         # Always compute rope in fp32
         grid_h = torch.arange(height, dtype=torch.float32, device=device)
         grid_w = torch.arange(width, dtype=torch.float32, device=device)
@@ -132,9 +140,7 @@ def _prepare_video_coords(self, batch_size: int, num_frames: int, height: int, w
             grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
         else:
             if not self._causal_rope_fix:
-                grid[:, 0:1] = (
-                    grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames
-                )
+                grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames
             else:
                 grid[:, 0:1] = (
                     ((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0)
@@ -145,9 +151,8 @@ def _prepare_video_coords(self, batch_size: int, num_frames: int, height: int, w
         grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width
 
         grid = grid.flatten(2, 4).transpose(1, 2)
-
+
         return grid
-
 
     def forward(
         self,
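
Note on the causal branch in the hunk above: when `_causal_rope_fix` is set, the temporal coordinate is shifted back by one frame, offset by one frame duration (`1 / frame_rate`), and clamped at zero, so frame 0 no longer maps to a negative time. A minimal runnable sketch of just that sub-expression (the standalone function and tensor names are illustrative, not part of the model; the patch-size/base-frame scaling that follows it is cut off by the hunk and elided here):

import torch

def causal_time_coords(frame_idx: torch.Tensor, scale_t: torch.Tensor, frame_rate: float) -> torch.Tensor:
    # Mirrors ((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0)
    return ((frame_idx - 1) * scale_t + 1 / frame_rate).clamp(min=0)

# Frame 0 clamps to time 0 instead of going negative:
t = causal_time_coords(torch.arange(4, dtype=torch.float32), torch.tensor(1.0), frame_rate=25.0)
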
@@ -162,14 +167,22 @@ def forward(
         batch_size = hidden_states.size(0)
 
         if video_coords is None:
-            grid = self._prepare_video_coords(batch_size, num_frames, height, width, rope_interpolation_scale=rope_interpolation_scale, device=hidden_states.device)
+            grid = self._prepare_video_coords(
+                batch_size,
+                num_frames,
+                height,
+                width,
+                rope_interpolation_scale=rope_interpolation_scale,
+                frame_rate=frame_rate,
+                device=hidden_states.device,
+            )
         else:
             grid = torch.stack(
                 [
-                    video_coords[:, 0] / self.base_num_frames,
-                    video_coords[:, 1] / self.base_height,
-                    video_coords[:, 2] / self.base_width
-                ],
+                    video_coords[:, 0] / self.base_num_frames,
+                    video_coords[:, 1] / self.base_height,
+                    video_coords[:, 2] / self.base_width,
+                ],
                 dim=-1,
             )
 
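For the explicit-coordinates path in this hunk, `video_coords` is indexed as `video_coords[:, 0]`, `[:, 1]`, `[:, 2]`, i.e. a `(batch, 3, num_tokens)` tensor of (frame, row, column) indices that `forward` then normalizes by `base_num_frames`, `base_height`, and `base_width`. A hedged construction sketch (the helper name and the (F, H, W) token ordering are assumptions, not taken from this patch):

import torch

def make_video_coords(batch: int, frames: int, height: int, width: int) -> torch.Tensor:
    # Per-token (frame, row, column) indices, flattened in (F, H, W) order.
    t = torch.arange(frames, dtype=torch.float32)
    h = torch.arange(height, dtype=torch.float32)
    w = torch.arange(width, dtype=torch.float32)
    coords = torch.stack(torch.meshgrid(t, h, w, indexing="ij"))  # (3, F, H, W)
    return coords.flatten(1).unsqueeze(0).expand(batch, -1, -1)   # (batch, 3, F*H*W)
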
@@ -432,8 +445,9 @@ def forward(
             msg = "Passing a tuple for `rope_interpolation_scale` is deprecated and will be removed in v0.34.0."
             deprecate("rope_interpolation_scale", "0.34.0", msg)
 
-
-        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale, video_coords)
+        image_rotary_emb = self.rope(
+            hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale, video_coords
+        )
 
         # convert encoder_attention_mask to a bias the same way we do for attention_mask
         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: