From 77ffd6a10cf5e05b521b30666f69f992cf20cfe7 Mon Sep 17 00:00:00 2001
From: Aditya Borate <23110065@iitgn.ac.in>
Date: Tue, 9 Dec 2025 19:41:15 +0530
Subject: [PATCH 1/3] Fix: Remove hardcoded CUDA autocast in Kandinsky 5 to fix import warning

---
 .../models/transformers/transformer_kandinsky.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py
index 316e79da4fd6..ef5d8b6444f4 100644
--- a/src/diffusers/models/transformers/transformer_kandinsky.py
+++ b/src/diffusers/models/transformers/transformer_kandinsky.py
@@ -165,13 +165,15 @@ def __init__(self, model_dim, time_dim, max_period=10000.0):
         self.activation = nn.SiLU()
         self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
 
-    @torch.autocast(device_type="cuda", dtype=torch.float32)
     def forward(self, time):
-        args = torch.outer(time, self.freqs.to(device=time.device))
+        time = time.to(dtype=torch.float32)
+        freqs = self.freqs.to(device=time.device, dtype=torch.float32)
+        args = torch.outer(time, freqs)
         time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
         time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
         return time_embed
-
+    
 
 class Kandinsky5TextEmbeddings(nn.Module):
     def __init__(self, text_dim, model_dim):
@@ -269,8 +271,8 @@ def __init__(self, time_dim, model_dim, num_params):
         self.out_layer.weight.data.zero_()
         self.out_layer.bias.data.zero_()
 
-    @torch.autocast(device_type="cuda", dtype=torch.float32)
     def forward(self, x):
+        x = x.to(dtype=self.out_layer.weight.dtype)
         return self.out_layer(self.activation(x))
 
 

From 9afba5fee0a5485efa74a6a0672237950a782ad8 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Thu, 11 Dec 2025 15:11:10 +0000
Subject: [PATCH 2/3] Apply style fixes

---
 src/diffusers/models/transformers/transformer_kandinsky.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py
index ef5d8b6444f4..b2b5baff7d95 100644
--- a/src/diffusers/models/transformers/transformer_kandinsky.py
+++ b/src/diffusers/models/transformers/transformer_kandinsky.py
@@ -173,7 +173,7 @@ def forward(self, time):
         time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
         time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
         return time_embed
-    
+
 
 class Kandinsky5TextEmbeddings(nn.Module):
     def __init__(self, text_dim, model_dim):

From 6eced01b7e987ecc8d3a84d16a639540f7dfaee9 Mon Sep 17 00:00:00 2001
From: Aditya Borate <23110065@iitgn.ac.in>
Date: Sun, 14 Dec 2025 12:00:07 +0530
Subject: [PATCH 3/3] Fix: Remove import-time autocast in Kandinsky to prevent warnings

- Removed @torch.autocast decorator from Kandinsky classes.
- Implemented manual F.linear casting to ensure numerical parity with FP32.
- Verified bit-exact output matches main branch.
Co-authored-by: hlky
---
 .../transformers/transformer_kandinsky.py | 24 +++++++++++++------
 1 file changed, 17 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py
index b2b5baff7d95..c841cc522d81 100644
--- a/src/diffusers/models/transformers/transformer_kandinsky.py
+++ b/src/diffusers/models/transformers/transformer_kandinsky.py
@@ -166,12 +166,19 @@ def __init__(self, model_dim, time_dim, max_period=10000.0):
         self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
 
     def forward(self, time):
-        time = time.to(dtype=torch.float32)
-        freqs = self.freqs.to(device=time.device, dtype=torch.float32)
-        args = torch.outer(time, freqs)
+        args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
         time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-        time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
-        time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
+        time_embed = F.linear(
+            self.activation(
+                F.linear(
+                    time_embed,
+                    self.in_layer.weight.to(torch.float32),
+                    self.in_layer.bias.to(torch.float32),
+                )
+            ),
+            self.out_layer.weight.to(torch.float32),
+            self.out_layer.bias.to(torch.float32),
+        )
         return time_embed
 
 
@@ -272,8 +279,11 @@ def __init__(self, time_dim, model_dim, num_params):
         self.out_layer.bias.data.zero_()
 
     def forward(self, x):
-        x = x.to(dtype=self.out_layer.weight.dtype)
-        return self.out_layer(self.activation(x))
+        return F.linear(
+            self.activation(x.to(torch.float32)),
+            self.out_layer.weight.to(torch.float32),
+            self.out_layer.bias.to(torch.float32),
+        )
 
 
 class Kandinsky5AttnProcessor:
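
Note for reviewers (not part of the patches above): the snippet below is a minimal, self-contained sketch of the parity check described in the PATCH 3/3 message. It uses a toy module, ToyTimeEmbeddings, with made-up dimensions rather than the actual diffusers Kandinsky5TimeEmbeddings; the class name, sizes, and the bfloat16 setup are assumptions for illustration, and only the structure of the float32 F.linear path mirrors the patch.

# Assumption-laden sketch, not the real diffusers classes.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyTimeEmbeddings(nn.Module):
    # Mirrors the shape of the patched forward: sinusoidal args, in_layer, SiLU, out_layer.
    def __init__(self, model_dim=8, time_dim=16, max_period=10000.0):
        super().__init__()
        half = model_dim // 2
        freqs = torch.exp(-torch.log(torch.tensor(max_period)) * torch.arange(half) / half)
        self.register_buffer("freqs", freqs)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        # Manual float32 path, as in PATCH 3/3: no autocast decorator, weights and
        # biases are upcast at call time inside F.linear.
        args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        hidden = F.linear(
            time_embed,
            self.in_layer.weight.to(torch.float32),
            self.in_layer.bias.to(torch.float32),
        )
        return F.linear(
            self.activation(hidden),
            self.out_layer.weight.to(torch.float32),
            self.out_layer.bias.to(torch.float32),
        )


torch.manual_seed(0)
model_bf16 = ToyTimeEmbeddings().to(torch.bfloat16)  # low-precision weights, fp32 math in forward
model_fp32 = ToyTimeEmbeddings()
# Give the fp32 reference the exact same weight values (bf16 upcast exactly to fp32).
model_fp32.load_state_dict({k: v.float() for k, v in model_bf16.state_dict().items()})

t = torch.rand(4)
print(torch.equal(model_bf16(t), model_fp32(t)))  # expected: True (bit-exact)

The intent, as stated in the patch subjects, is that the float32 computation no longer depends on a @torch.autocast(device_type="cuda") decorator evaluated at import time, which the commit messages identify as the source of the import warning, while the output stays bit-exact with an fp32 reference.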