@@ -158,19 +158,21 @@ class Kandinsky5TimeEmbeddings(nn.Module):
     def __init__(self, model_dim, time_dim, max_period=10000.0):
         super().__init__()
         assert model_dim % 2 == 0
+        print(f"{model_dim = }, {time_dim = }")
         self.model_dim = model_dim
         self.max_period = max_period
         self.freqs = get_freqs(self.model_dim // 2, self.max_period)
         self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
         self.activation = nn.SiLU()
         self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
 
-    @torch.autocast(device_type="cuda", dtype=torch.float32)
     def forward(self, time):
-        args = torch.outer(time, self.freqs.to(device=time.device))
+        original_dtype = time.dtype
+        print(f"{original_dtype = }")
+        args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
         time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-        time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
-        return time_embed
+        time_embed = F.linear(self.activation(F.linear(time_embed, self.in_layer.weight.to(torch.float32), self.in_layer.bias.to(torch.float32))), self.out_layer.weight.to(torch.float32), self.out_layer.bias.to(torch.float32))
+        return time_embed.to(original_dtype)
 
 
 class Kandinsky5TextEmbeddings(nn.Module):
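As context for this hunk: the commit drops the CUDA-only `torch.autocast` decorator and instead forces the time-embedding math into float32 explicitly, restoring the caller's dtype at the end. Below is a minimal, self-contained sketch of that pattern; the module name `Float32TimeEmbed`, the inline frequency computation standing in for `get_freqs`, and the dimensions are illustrative assumptions, not the actual diffusers code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Float32TimeEmbed(nn.Module):
    """Hypothetical sketch of the float32 time-embedding pattern in the hunk above."""

    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        half = model_dim // 2
        # Sinusoidal frequencies; stands in for get_freqs() from the original file.
        self.freqs = torch.exp(
            -torch.log(torch.tensor(max_period)) * torch.arange(half) / half
        )
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        original_dtype = time.dtype
        # Build the sinusoidal embedding in float32, whatever dtype the timesteps arrive in.
        args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        # Run the MLP via F.linear with weights cast to float32 per call,
        # instead of relying on a device-specific torch.autocast context.
        emb = F.linear(emb, self.in_layer.weight.float(), self.in_layer.bias.float())
        emb = F.linear(self.activation(emb), self.out_layer.weight.float(), self.out_layer.bias.float())
        # Hand the result back in the caller's original dtype.
        return emb.to(original_dtype)


# Usage: bf16 weights and timesteps outside, float32 math inside, bf16 result.
embed = Float32TimeEmbed(model_dim=64, time_dim=128).to(torch.bfloat16)
t = torch.tensor([0.0, 250.0, 500.0, 999.0], dtype=torch.bfloat16)
print(embed(t).dtype)  # torch.bfloat16
```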
@@ -271,7 +273,7 @@ def __init__(self, time_dim, model_dim, num_params):
 
     @torch.autocast(device_type="cuda", dtype=torch.float32)
     def forward(self, x):
-        return self.out_layer(self.activation(x))
+        return F.linear(self.activation(x.to(torch.float32)), self.out_layer.weight.to(torch.float32), self.out_layer.bias.to(torch.float32)).type_as(x)
 
 
 class Kandinsky5AttnProcessor:
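A plausible reason (an assumption, not stated in the commit) for routing through `F.linear` with per-call weight casts, rather than upcasting the layer itself, is that `nn.Module.to()` converts the stored parameters in place, whereas the functional call only casts temporary copies. A tiny sketch with hypothetical sizes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

layer = nn.Linear(16, 16).to(torch.bfloat16)
x = torch.randn(2, 16, dtype=torch.bfloat16)

# Functional path: cast copies of the weights for this call only.
y = F.linear(x.to(torch.float32), layer.weight.float(), layer.bias.float()).type_as(x)
print(y.dtype)            # torch.bfloat16 -- output matches the input dtype via .type_as(x)
print(layer.weight.dtype)  # torch.bfloat16 -- module parameters unchanged

# In-place path: nn.Module.to() converts the stored parameters themselves.
layer.to(torch.float32)
print(layer.weight.dtype)  # torch.float32
```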