@@ -187,10 +187,6 @@ def forward(
187187
188188 x_reshaped = x .reshape (batch , height , width , features ).permute (0 , 3 , 1 , 2 )
189189
190- if self .score_model .in_channels == features + 2 :
191- pos = self ._get_position_encoding (height , width , batch , x .device )
192- x_reshaped = torch .cat ([x_reshaped , pos ], dim = 1 )
193-
194190 if isinstance (t , int ):
195191 t = torch .tensor (t , device = x .device )
196192 elif isinstance (t , torch .Tensor ):
@@ -205,8 +201,15 @@ def forward(
205201 sqrt_one_minus_alpha = (1.0 - self .alphas_cumprod [t ]).sqrt ().to (x .device )
206202
207203 noisy_x = sqrt_alpha * x_reshaped + sqrt_one_minus_alpha * noise
208- score = self .score_model (noisy_x )
209- pred_x = (noisy_x - sqrt_one_minus_alpha * score ) / sqrt_alpha
204+
205+ if self .score_model .in_channels == features + 2 :
206+ pos = self ._get_position_encoding (height , width , batch , x .device )
207+ noisy_x_with_pos = torch .cat ([noisy_x , pos ], dim = 1 )
208+ else :
209+ noisy_x_with_pos = noisy_x
210+
211+ predicted_noise = self .score_model (noisy_x_with_pos )
212+ pred_x = (noisy_x - sqrt_one_minus_alpha * predicted_noise ) / sqrt_alpha
210213
211214 return pred_x .permute (0 , 2 , 3 , 1 ).reshape (total_nodes , features )
212215
0 commit comments