Skip to content

Commit 245137f

Browse files
committed
remove todo comments about review; revert changes to self.dtype in unets because .dtype on ModelMixin should be able to handle fp8 weight case
1 parent 93e36ba commit 245137f

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

src/diffusers/models/unets/unet_2d.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,7 @@ def forward(
295295
# timesteps does not contain any weights and will always return f32 tensors
296296
# but time_embedding might actually be running in fp16. so we need to cast here.
297297
# there might be better ways to encapsulate this.
298-
# TODO(aryan): Need to have this reviewed
299-
t_emb = t_emb.to(dtype=sample.dtype)
298+
t_emb = t_emb.to(dtype=self.dtype)
300299
emb = self.time_embedding(t_emb)
301300

302301
if self.class_embedding is not None:
@@ -306,7 +305,7 @@ def forward(
306305
if self.config.class_embed_type == "timestep":
307306
class_labels = self.time_proj(class_labels)
308307

309-
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
308+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
310309
emb = emb + class_emb
311310
elif self.class_embedding is None and class_labels is not None:
312311
raise ValueError("class_embedding needs to be initialized in order to use class conditioning")

src/diffusers/models/unets/unet_motion_model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2133,8 +2133,7 @@ def forward(
21332133
# timesteps does not contain any weights and will always return f32 tensors
21342134
# but time_embedding might actually be running in fp16. so we need to cast here.
21352135
# there might be better ways to encapsulate this.
2136-
# TODO(aryan): Need to have this reviewed
2137-
t_emb = t_emb.to(dtype=sample.dtype)
2136+
t_emb = t_emb.to(dtype=self.dtype)
21382137

21392138
emb = self.time_embedding(t_emb, timestep_cond)
21402139
aug_emb = None

0 commit comments

Comments
 (0)