Skip to content

Commit 753d747

Browse files
lint
1 parent b4ac4b1 commit 753d747

File tree

4 files changed

+21
-11
lines changed

4 files changed

+21
-11
lines changed

src/weathergen/model/diffusion.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,18 @@
2323
# ----------------------------------------------------------------------------
2424

2525

26+
import logging
2627
import math
2728

2829
import torch
29-
import logging
3030

3131
from weathergen.common.config import Config
3232
from weathergen.datasets.batch import SampleMetaData
3333
from weathergen.model.engines import ForecastingEngine
3434

3535
logger = logging.getLogger(__name__)
3636

37+
3738
class DiffusionForecastEngine(torch.nn.Module):
3839
# Adopted from https://github.com/NVlabs/edm/blob/main/training/loss.py#L72
3940

@@ -56,7 +57,7 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast
5657
self.rho = self.cf.rho
5758
self.p_mean = self.cf.p_mean
5859
self.p_std = self.cf.p_std
59-
self.cur_token = None #for debugging only
60+
self.cur_token = None # for debugging only
6061

6162
def forward(
6263
self, tokens: torch.Tensor, fstep: int, meta_info: dict[str, SampleMetaData]
@@ -75,12 +76,21 @@ def forward(
7576

7677
c = 1 # TODO: add correct preconditioning (e.g., sample/s in previous time step)
7778

78-
#TODO: remove after single sample experiments
79+
# TODO: remove after single sample experiments
7980
if self.cur_token is not None:
8081
logger.info("checking single sampling")
81-
assert self.cur_token[0].shape == tokens[0].shape, 'first token shape was different between iterations – violates single sample overfitting with difference'
82-
assert torch.equal(self.cur_token[0], tokens[0]), f'first token was different between iterations – violates single sample overfitting {self.cur_token[0] - tokens[0]}'
83-
assert torch.equal(self.cur_token, tokens), 'tokens were different between iterations – violates single sample overfitting'
82+
assert self.cur_token[0].shape == tokens[0].shape, (
83+
"first token shape was different between iterations "
84+
"– violates single sample overfitting with difference"
85+
)
86+
assert torch.equal(self.cur_token[0], tokens[0]), (
87+
f"first token was different between iterations "
88+
f"– violates single sample overfitting {self.cur_token[0] - tokens[0]}"
89+
)
90+
assert torch.equal(self.cur_token, tokens), (
91+
f"tokens were different between iterations "
92+
f"– violates single sample overfitting {self.cur_token - tokens}"
93+
)
8494
self.cur_token = tokens
8595

8696
y = tokens

src/weathergen/model/encoder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ def forward(self, model_params, batch):
124124

125125
return global_tokens, posteriors
126126

127-
128127
def assimilate_local(
129128
self, model_params, tokens: torch.Tensor, batch: ModelBatch
130129
) -> torch.Tensor:

src/weathergen/model/engines.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ def __init__(self, cf: Config, sources_size, stream_names: list[str]) -> None:
8282
)
8383
else:
8484
raise ValueError("Unsupported embedding network type")
85-
8685

8786
def forward(self, batch, pe_embed):
8887
num_steps_input = batch.get_num_steps()
@@ -108,7 +107,7 @@ def forward(self, batch, pe_embed):
108107

109108
# embedding from physical space to per patch latent representation
110109
x_embeds += [self.embeds[stream_name](sdata).flatten(0, 1)]
111-
110+
112111
# switch from stream to cell-based ordering and apply per cell positional encoding
113112

114113
# computer scatter index across batch items and input steps

src/weathergen/model/model_interface.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def init_model_and_shard(
167167
if cf.chkpt_encoder_weights:
168168
if is_root():
169169
logger.info(
170-
f"Loading chkpt from run_id={cf.chkpt_encoder_weights}"\
170+
f"Loading chkpt from run_id={cf.chkpt_encoder_weights}"
171171
f" at mini_epoch {cf.chkpt_encoder_mini_epoch}."
172172
)
173173

@@ -196,6 +196,7 @@ def init_model_and_shard(
196196

197197
return model, model_params
198198

199+
199200
def load_encoder(cf, model, encoder_modules, device, run_id: str, mini_epoch=-1):
200201
"""Loads model state from checkpoint and checks for missing and unused keys.
201202
Args:
@@ -282,7 +283,8 @@ def load_encoder(cf, model, encoder_modules, device, run_id: str, mini_epoch=-1)
282283
logger.warning(f"Unused keys when loading model: {mkeys}")
283284

284285
return model
285-
286+
287+
286288
def load_model(cf, model, device, run_id: str, mini_epoch=-1):
287289
"""Loads model state from checkpoint and checks for missing and unused keys.
288290
Args:

0 commit comments

Comments (0)