Skip to content

Commit e05f8cc

Browse files
authored
Fix Thermalizer diffusion step and positional encoding handling (#181)
1 parent 61f30ec commit e05f8cc

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

graph_weather/models/layers/thermalizer.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,6 @@ def forward(
187187

188188
x_reshaped = x.reshape(batch, height, width, features).permute(0, 3, 1, 2)
189189

190-
if self.score_model.in_channels == features + 2:
191-
pos = self._get_position_encoding(height, width, batch, x.device)
192-
x_reshaped = torch.cat([x_reshaped, pos], dim=1)
193-
194190
if isinstance(t, int):
195191
t = torch.tensor(t, device=x.device)
196192
elif isinstance(t, torch.Tensor):
@@ -205,8 +201,15 @@ def forward(
205201
sqrt_one_minus_alpha = (1.0 - self.alphas_cumprod[t]).sqrt().to(x.device)
206202

207203
noisy_x = sqrt_alpha * x_reshaped + sqrt_one_minus_alpha * noise
208-
score = self.score_model(noisy_x)
209-
pred_x = (noisy_x - sqrt_one_minus_alpha * score) / sqrt_alpha
204+
205+
if self.score_model.in_channels == features + 2:
206+
pos = self._get_position_encoding(height, width, batch, x.device)
207+
noisy_x_with_pos = torch.cat([noisy_x, pos], dim=1)
208+
else:
209+
noisy_x_with_pos = noisy_x
210+
211+
predicted_noise = self.score_model(noisy_x_with_pos)
212+
pred_x = (noisy_x - sqrt_one_minus_alpha * predicted_noise) / sqrt_alpha
210213

211214
return pred_x.permute(0, 2, 3, 1).reshape(total_nodes, features)
212215

0 commit comments

Comments
 (0)