Skip to content

Commit 6f8d800

Browse files
committed
Refactor: Use get_parameter_dtype utility function
Replaces manual parameter iteration with the `get_parameter_dtype` helper.
1 parent e8426ba commit 6f8d800

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/diffusers/models/transformers/transformer_skyreels_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
get_1d_sincos_pos_embed_from_grid,
3333
)
3434
from ..modeling_outputs import Transformer2DModelOutput
35-
from ..modeling_utils import ModelMixin
35+
from ..modeling_utils import ModelMixin, get_parameter_dtype
3636
from ..normalization import FP32LayerNorm
3737

3838

@@ -198,7 +198,7 @@ def forward(
198198
):
199199
timestep = self.timesteps_proj(timestep)
200200

201-
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
201+
time_embedder_dtype = get_parameter_dtype(self.time_embedder)
202202
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
203203
timestep = timestep.to(time_embedder_dtype)
204204
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)

0 commit comments

Comments
 (0)