
Commit eb7ae84

-autocast
1 parent 17c0e79 commit eb7ae84

File tree: 1 file changed (+7, -5 lines)


src/diffusers/models/transformers/transformer_kandinsky.py

Lines changed: 7 additions & 5 deletions
@@ -158,19 +158,21 @@ class Kandinsky5TimeEmbeddings(nn.Module):
     def __init__(self, model_dim, time_dim, max_period=10000.0):
         super().__init__()
         assert model_dim % 2 == 0
+        print(f"{model_dim=}, {time_dim=}")
         self.model_dim = model_dim
         self.max_period = max_period
         self.freqs = get_freqs(self.model_dim // 2, self.max_period)
         self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
         self.activation = nn.SiLU()
         self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
 
-    @torch.autocast(device_type="cuda", dtype=torch.float32)
     def forward(self, time):
-        args = torch.outer(time, self.freqs.to(device=time.device))
+        original_dtype = time.dtype
+        print(f"{original_dtype=}")
+        args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
         time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-        time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
-        return time_embed
+        time_embed = F.linear(self.activation(F.linear(time_embed, self.in_layer.weight.to(torch.float32), self.in_layer.bias.to(torch.float32))), self.out_layer.weight.to(torch.float32), self.out_layer.bias.to(torch.float32))
+        return time_embed.to(original_dtype)
 
 
 class Kandinsky5TextEmbeddings(nn.Module):
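In this hunk the autocast decorator on forward is replaced by explicit float32 arithmetic: the timestep is cast to float32 before the sinusoidal embedding is built, both Linear layers are applied through F.linear with float32 copies of their weights and biases, and the result is cast back to the input dtype. Below is a minimal standalone sketch of that pattern; the class name and the get_freqs helper are illustrative assumptions (get_freqs is assumed to build the usual sinusoidal frequency table), not code taken from the repository.

import torch
import torch.nn as nn
import torch.nn.functional as F


def get_freqs(half_dim, max_period=10000.0):
    # Assumed helper: standard sinusoidal frequency table, max_period ** (-i / half_dim).
    return torch.exp(
        -torch.arange(half_dim, dtype=torch.float32)
        * torch.log(torch.tensor(max_period))
        / half_dim
    )


class TimeEmbeddingsFP32(nn.Module):
    """Sketch of a time-embedding forward computed entirely in float32."""

    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        self.freqs = get_freqs(model_dim // 2, max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        original_dtype = time.dtype
        # Build the sinusoidal embedding in float32 regardless of the module/input dtype.
        args = torch.outer(time.to(torch.float32), self.freqs.to(time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        # Apply both projections with float32 copies of the parameters, then cast back.
        hidden = F.linear(time_embed, self.in_layer.weight.to(torch.float32), self.in_layer.bias.to(torch.float32))
        time_embed = F.linear(self.activation(hidden), self.out_layer.weight.to(torch.float32), self.out_layer.bias.to(torch.float32))
        return time_embed.to(original_dtype)

Unlike the autocast decorator, this keeps the forward device-agnostic and guarantees float32 precision for the small time-embedding MLP even when the rest of the model runs in half precision.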
@@ -271,7 +273,7 @@ def __init__(self, time_dim, model_dim, num_params):
 
     @torch.autocast(device_type="cuda", dtype=torch.float32)
     def forward(self, x):
-        return self.out_layer(self.activation(x))
+        return F.linear(self.activation(x.to(torch.float32)), self.out_layer.weight.to(torch.float32), self.out_layer.bias.to(torch.float32)).type_as(x)
 
 
 class Kandinsky5AttnProcessor:
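The second hunk applies the same idea to the modulation projection while keeping its autocast decorator: the input is cast to float32, the output projection runs through F.linear with float32 copies of its parameters, and the result is returned in the caller's dtype via type_as. A minimal sketch follows; the class name and the constructor's layer shape are assumptions for illustration, only the forward mirrors the diff.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ModulationFP32(nn.Module):
    """Sketch of a modulation head whose final projection always runs in float32."""

    def __init__(self, time_dim, model_dim, num_params):
        super().__init__()
        self.activation = nn.SiLU()
        # Output width is an assumption for the sketch.
        self.out_layer = nn.Linear(time_dim, num_params * model_dim, bias=True)

    def forward(self, x):
        # Cast the input and the projection parameters to float32,
        # compute the projection, then return the result in the caller's dtype.
        out = F.linear(
            self.activation(x.to(torch.float32)),
            self.out_layer.weight.to(torch.float32),
            self.out_layer.bias.to(torch.float32),
        )
        return out.type_as(x)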
