Skip to content

Commit f3facc7

Browse files
Sophiex/dev/name modules (ecmwf#754)
* Add names to modules as prep for freezing * Add functionality to freeze modules based on added names * Ruff * Clean up * Wrong import path * Ruff
1 parent 32b4d91 commit f3facc7

File tree

6 files changed

+52
-3
lines changed

6 files changed

+52
-3
lines changed

config/default_config.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ loss_fcts_val:
7373
batch_size_per_gpu: 1
7474
batch_size_validation_per_gpu: 1
7575

76+
# a regex that needs to fully match the name of the modules you want to freeze
77+
# e.g. ".*ERA5" will match any module whose name ends in "ERA5"
78+
# encoders and decoders that exist per stream have the stream name attached at the end
79+
freeze_modules: ""
80+
7681
# training mode: "forecast" or "masking" (masked token modeling)
7782
# for "masking" to train with auto-encoder mode, forecast_offset should be 0
7883
training_mode: "masking"

src/weathergen/model/embeddings.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
norm_type="LayerNorm",
3535
embed_size_centroids=64,
3636
unembed_mode="full",
37+
stream_name="stream_embed",
3738
):
3839
"""Constructor
3940
@@ -46,6 +47,8 @@ def __init__(
4647

4748
super(StreamEmbedTransformer, self).__init__()
4849

50+
self.name = f"StreamEmbedder_{stream_name}"
51+
4952
self.num_tokens = num_tokens
5053
self.token_size = token_size
5154
self.num_channels = num_channels
@@ -194,11 +197,12 @@ def forward_columns(self, x_in, centroids):
194197

195198

196199
class StreamEmbedLinear(torch.nn.Module):
197-
def __init__(self, dim_in, dim_out):
200+
def __init__(self, dim_in, dim_out, stream_name="stream_embed"):
198201
"""Constructor"""
199202

200203
super(StreamEmbedLinear, self).__init__()
201204

205+
self.name = f"StreamEmbedder_{stream_name}"
202206
self.layer = torch.nn.Linear(dim_in, dim_out)
203207

204208
def forward(self, x):

src/weathergen/model/engines.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030

3131
class EmbeddingEngine:
32+
name = "EmbeddingEngine"
33+
3234
def __init__(self, cf: Config, sources_size) -> None:
3335
"""
3436
Initialize the EmbeddingEngine with the configuration.
@@ -47,6 +49,8 @@ def create(self) -> torch.nn.ModuleList:
4749
:return: torch.nn.ModuleList containing the embedding layers.
4850
"""
4951
for i, si in enumerate(self.cf.streams):
52+
stream_name = si.get("name", i)
53+
5054
if "diagnostic" in si and si["diagnostic"]:
5155
self.embeds.append(torch.nn.Identity())
5256
continue
@@ -66,12 +70,15 @@ def create(self) -> torch.nn.ModuleList:
6670
norm_type=self.cf.norm_type,
6771
embed_size_centroids=self.cf.embed_size_centroids,
6872
unembed_mode=self.cf.embed_unembed_mode,
73+
stream_name=stream_name,
6974
)
7075
)
7176
elif si["embed"]["net"] == "linear":
7277
self.embeds.append(
7378
StreamEmbedLinear(
74-
self.sources_size[i] * si["token_size"], self.cf.ae_local_dim_embed
79+
self.sources_size[i] * si["token_size"],
80+
self.cf.ae_local_dim_embed,
81+
stream_name=stream_name,
7582
)
7683
)
7784
else:
@@ -80,6 +87,8 @@ def create(self) -> torch.nn.ModuleList:
8087

8188

8289
class LocalAssimilationEngine:
90+
name = "LocalAssimilationEngine"
91+
8392
def __init__(self, cf: Config) -> None:
8493
"""
8594
Initialize the LocalAssimilationEngine with the configuration.
@@ -122,6 +131,8 @@ def create(self) -> torch.nn.ModuleList:
122131

123132

124133
class Local2GlobalAssimilationEngine:
134+
name = "Local2GlobalAssimilationEngine"
135+
125136
def __init__(self, cf: Config) -> None:
126137
"""
127138
Initialize the Local2GlobalAssimilationEngine with the configuration.
@@ -183,6 +194,8 @@ def create(self) -> torch.nn.ModuleList:
183194

184195

185196
class GlobalAssimilationEngine:
197+
name = "GlobalAssimilationEngine"
198+
186199
def __init__(self, cf: Config, num_healpix_cells: int) -> None:
187200
"""
188201
Initialize the GlobalAssimilationEngine with the configuration.
@@ -250,6 +263,8 @@ def create(self) -> torch.nn.ModuleList:
250263

251264

252265
class ForecastingEngine:
266+
name = "ForecastingEngine"
267+
253268
def __init__(self, cf: Config, num_healpix_cells: int) -> None:
254269
"""
255270
Initialize the ForecastingEngine with the configuration.
@@ -327,13 +342,13 @@ def init_weights_final(m):
327342

328343

329344
class EnsPredictionHead(torch.nn.Module):
330-
#########################################
331345
def __init__(
332346
self,
333347
dim_embed,
334348
dim_out,
335349
ens_num_layers,
336350
ens_size,
351+
stream_name: str,
337352
norm_type="LayerNorm",
338353
hidden_factor=2,
339354
final_activation: None | str = None,
@@ -342,6 +357,8 @@ def __init__(
342357

343358
super(EnsPredictionHead, self).__init__()
344359

360+
self.name = f"EnsPredictionHead_{stream_name}"
361+
345362
dim_internal = dim_embed * hidden_factor
346363
# norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm
347364
enl = ens_num_layers
@@ -390,6 +407,7 @@ def __init__(
390407
tr_mlp_hidden_factor,
391408
softcap,
392409
tro_type,
410+
stream_name: str,
393411
):
394412
"""
395413
Initialize the TargetPredictionEngine with the configuration.
@@ -403,6 +421,7 @@ def __init__(
403421
:param tro_type: Type of target readout (e.g., "obs_value").
404422
"""
405423
super(TargetPredictionEngineClassic, self).__init__()
424+
self.name = f"TargetPredictionEngine_{stream_name}"
406425

407426
self.cf = cf
408427
self.dims_embed = dims_embed
@@ -496,6 +515,7 @@ def __init__(
496515
tr_mlp_hidden_factor,
497516
softcap,
498517
tro_type,
518+
stream_name: str,
499519
):
500520
"""
501521
Initialize the TargetPredictionEngine with the configuration.
@@ -519,6 +539,7 @@ def __init__(
519539
LayerNorm that does not scale after the layer is applied
520540
"""
521541
super(TargetPredictionEngine, self).__init__()
542+
self.name = f"TargetPredictionEngine_{stream_name}"
522543

523544
self.cf = cf
524545
self.dims_embed = dims_embed

src/weathergen/model/layers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,15 @@ def __init__(
2727
norm_type="LayerNorm",
2828
dim_aux=None,
2929
norm_eps=1e-5,
30+
name: str | None = None,
3031
):
3132
"""Constructor"""
3233

3334
super(MLP, self).__init__()
3435

36+
if name is not None:
37+
self.name = name
38+
3539
assert num_layers >= 2
3640

3741
self.with_residual = with_residual

src/weathergen/model/model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,8 @@ def create(self) -> "Model":
260260
self.pred_heads = torch.nn.ModuleList()
261261

262262
for i_obs, si in enumerate(cf.streams):
263+
stream_name = si.get("name", i_obs)
264+
263265
# extract and setup relevant parameters
264266
etc = si["embed_target_coords"]
265267
tro_type = si["target_readout"]["type"] if "type" in si["target_readout"] else "token"
@@ -310,6 +312,7 @@ def create(self) -> "Model":
310312
with_residual=False,
311313
dropout_rate=dropout_rate,
312314
norm_eps=self.cf.mlp_norm_eps,
315+
stream_name=f"embed_target_coords_{stream_name}",
313316
)
314317
)
315318
else:
@@ -326,6 +329,7 @@ def create(self) -> "Model":
326329
dropout_rate=dropout_rate,
327330
norm_type=cf.norm_type,
328331
norm_eps=self.cf.mlp_norm_eps,
332+
stream_name=f"pred_adapter_kv_{stream_name}",
329333
)
330334
)
331335
else:
@@ -345,6 +349,7 @@ def create(self) -> "Model":
345349
tr_mlp_hidden_factor,
346350
softcap,
347351
tro_type,
352+
stream_name=stream_name,
348353
)
349354

350355
self.target_token_engines.append(tte)
@@ -362,6 +367,7 @@ def create(self) -> "Model":
362367
si["pred_head"]["ens_size"],
363368
norm_type=cf.norm_type,
364369
final_activation=final_activation,
370+
stream_name=stream_name,
365371
)
366372
)
367373

src/weathergen/train/trainer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# granted to it by virtue of its status as an intergovernmental organisation
1010
# nor does it submit to any jurisdiction.
1111

12+
import re
1213
import time
1314
from typing import Any
1415

@@ -29,6 +30,7 @@
2930
import weathergen.utils.config as config
3031
from weathergen.datasets.multi_stream_data_sampler import MultiStreamDataSampler
3132
from weathergen.model.model import Model, ModelParams
33+
from weathergen.model.utils import freeze_weights
3234
from weathergen.train.loss_calculator import LossCalculator
3335
from weathergen.train.lr_scheduler import LearningRateScheduler
3436
from weathergen.train.trainer_base import TrainerBase
@@ -52,6 +54,8 @@ def init(
5254
):
5355
self.cf = cf
5456

57+
self.freeze_modules = cf.get("freeze_modules", "")
58+
5559
assert cf.samples_per_epoch % cf.batch_size_per_gpu == 0
5660
assert cf.samples_per_validation % cf.batch_size_validation_per_gpu == 0
5761
assert cf.forecast_policy if cf.forecast_steps > 0 else True
@@ -182,6 +186,11 @@ def run(self, cf, run_id_contd=None, epoch_contd=None):
182186
if cf.forecast_freeze_model:
183187
self.model = self.model.freeze_weights_forecast()
184188

189+
for name, module in self.model.named_modules():
190+
name = getattr(module, "name", None)
191+
if name is not None and re.fullmatch(self.freeze_modules, name):
192+
freeze_weights(module)
193+
185194
self.model = self.model.to(self.devices[0])
186195

187196
if cf.compile_model:

0 commit comments

Comments
 (0)