Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
970c4dc
add pin mem to IOReaderData
Dec 24, 2025
5c566df
add pin mem to sample & modelbatch class
Dec 24, 2025
e85309d
add pin mem to stream data
Dec 24, 2025
ac3b089
add pin mem to training loop
Dec 24, 2025
c3fc9a7
run /scripts/actions.sh lint
Dec 29, 2025
7ac3b3e
run ./scripts/actions.sh unit-test
Dec 29, 2025
a65f561
ignore check torch import in package
Dec 29, 2025
98f4e0b
move pinning to MultiStreamDataSampler
Dec 30, 2025
bc80b26
add _pin_tensor & _pin_tensor_list helper func
Dec 30, 2025
8f98482
ruff the code
Dec 30, 2025
ea8f16c
move back pin mem. to train loop
Dec 30, 2025
61433eb
Remove the ignore-import-error rule and revert to the state before th…
Dec 30, 2025
48c51e3
create protocol for pinnable obj
Dec 30, 2025
dc40a2f
remove pin_mem from IOReaderData class
Dec 30, 2025
36c4b9c
add pin_memory to Trainer.validate
Dec 30, 2025
ebec481
remove pin_memory from loader_params
Dec 30, 2025
62c4e02
Rever export/export_inference.py to state before c3fc9a78
Dec 30, 2025
6a22234
change name
Jan 6, 2026
3796bc8
revise Pinnable class description
Jan 6, 2026
e29160a
Merge branch 'ecmwf:develop' into javad/dev/manual-mem-pinning-1399
javak87 Jan 7, 2026
7fe5b44
add memory_pinning in config, train & va loop
Jan 13, 2026
20944f3
Merge branch 'develop' into javad/dev/manual-mem-pinning-1399
javak87 Jan 14, 2026
08078e8
use getattr to avoid CICD warning
Jan 14, 2026
bd57cf4
use setattr to avoid CICD warning
Jan 14, 2026
503d742
disable pylint for self.source_tokens_lens
Jan 14, 2026
71461b6
Merge branch 'develop' into javad/dev/manual-mem-pinning-1399
clessig Jan 14, 2026
a31d6ea
changes based on #1615
Jan 16, 2026
7a98a08
Merge branch 'javad/dev/manual-mem-pinning-1399' of https://github.co…
Jan 16, 2026
039121b
Merge branch 'develop' into javad/dev/manual-mem-pinning-1399
javak87 Jan 16, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion config/config_physical_jepa.yml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ data_loading :
num_workers: 12
rng_seed: ???

# Pin host (CPU) memory for faster CPU-to-GPU transfers. Enabling memory_pinning together
# with FSDP2 + DINOv2 may cause the job to hang and trigger a PyTorch timeout error.
# If this happens, disable the flag; note that performance will then drop on GH200.
memory_pinning: True


# config for training
training_config:
Expand Down Expand Up @@ -320,4 +325,4 @@ wgtags:
# *** Experiment-specific tags ***
# All extra tags (including lists, dictionaries, etc.) are treated
# as strings by mlflow, so treat all extra tags as simple string key: value pairs.
grid: null
grid: null
5 changes: 5 additions & 0 deletions config/default_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ data_loading :
rng_seed: ???
repeat_data_in_mini_epoch : False

# Pin host (CPU) memory for faster CPU-to-GPU transfers. Enabling memory_pinning together
# with FSDP2 + DINOv2 may cause the job to hang and trigger a PyTorch timeout error.
# If this happens, disable the flag; note that performance will then drop on GH200.
memory_pinning: True


# config for training
training_config:
Expand Down
43 changes: 43 additions & 0 deletions src/weathergen/datasets/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,25 @@ class Sample:
# keys: stream_name, values: StreamData
streams_data: dict[str, StreamData | None]

def pin_memory(self):
    """Pin the tensors held by this Sample and return the Sample itself.

    Every non-None entry of ``streams_data`` that offers a ``pin_memory()``
    method is asked to pin itself, and the ``mask`` tensor of every
    ``SampleMetaData`` entry in ``meta_info`` is replaced by its pinned copy.
    """

    # StreamData entries know how to pin their own tensors
    streams = getattr(self, "streams_data", None)
    if isinstance(streams, dict):
        for stream in streams.values():
            if stream is None or not hasattr(stream, "pin_memory"):
                continue
            stream.pin_memory()

    # only the mask tensor of each meta-data record is pinned
    meta_info = getattr(self, "meta_info", None)
    if isinstance(meta_info, dict):
        for record in meta_info.values():
            if not isinstance(record, SampleMetaData):
                continue
            # a None mask is not a Tensor, so it is skipped here as well
            if isinstance(record.mask, torch.Tensor):
                record.mask = record.mask.pin_memory()

    return self

def __init__(self, streams: dict) -> None:
self.meta_info = {}

Expand Down Expand Up @@ -156,6 +175,19 @@ def get_device(self) -> str | torch.device:
"""
return self.device

def pin_memory(self):
    """Pin every sample of this batch plus ``tokens_lens`` and return the batch."""

    # each sample pins its own content
    for item in self.samples:
        item.pin_memory()

    # tokens_lens is only re-assigned when it actually is a tensor
    lens = self.tokens_lens
    if isinstance(lens, torch.Tensor):
        self.tokens_lens = lens.pin_memory()

    return self


class ModelBatch:
"""
Expand Down Expand Up @@ -186,6 +218,17 @@ def __init__(self, streams: dict, num_source_samples: int, num_target_samples: i
self.source2target_matching_idxs = np.full(num_source_samples, -1, dtype=np.int32)
self.target2source_matching_idxs = [[] for _ in range(num_target_samples)]

def pin_memory(self):
    """Pin the source samples, then the target samples, and return this batch."""

    # source first, target second; each container pins its own content
    for attr_name in ("source_samples", "target_samples"):
        getattr(self, attr_name).pin_memory()

    return self

def to_device(self, device): # -> ModelBatch
"""
Move batch to device
Expand Down
42 changes: 42 additions & 0 deletions src/weathergen/datasets/memory_pinning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Protocol, runtime_checkable

import torch

from weathergen.common.io import IOReaderData


@runtime_checkable
class Pinnable(Protocol):
    """
    Structural interface for data structures whose pytorch content can be
    placed in pinned memory.

    Any object exposing a ``pin_memory()`` method satisfies this protocol,
    which carries the ``pin_memory()`` capability of a torch Tensor over to
    other (container) classes.

    Pinning is a blocking operation.
    """

    def pin_memory(self):
        """Pin the pytorch content of this object."""
        ...


def pin_object(obj: Pinnable | torch.Tensor | IOReaderData | list | dict | None):
    """Pin ``obj`` (recursively) and return the pinned object.

    ``torch.Tensor.pin_memory()`` does not pin in place: it returns a pinned
    copy. The result therefore has to be kept, otherwise the pinned copy is
    dropped and the original, unpinned tensor is used for the transfer.
    Containers (lists, dicts and ``IOReaderData``) are updated in place with
    the pinned versions of their elements.

    Parameters
    ----------
    obj :
        Object to pin. ``None`` and objects of unsupported types (e.g. numpy
        arrays) are passed through unchanged.

    Returns
    -------
    The pinned tensor for tensor input; otherwise ``obj`` itself (containers
    are modified in place). Callers holding a bare tensor must use the
    return value.
    """
    if obj is None:
        return None

    if isinstance(obj, torch.Tensor | Pinnable):
        pinned = obj.pin_memory()
        # Tensors return a pinned copy; Pinnable containers may pin in place
        # and return themselves or nothing at all.
        return obj if pinned is None else pinned

    if isinstance(obj, IOReaderData):
        # Special case: IOReaderData is in common package and can't have torch deps
        # Note: These SHOULD be numpy arrays per the type hints, but might be tensors
        for field_name in ("coords", "data", "geoinfos"):
            current = getattr(obj, field_name)
            pinned = pin_object(current)
            # only write back when pinning produced a new object
            if pinned is not current:
                setattr(obj, field_name, pinned)
        return obj

    if isinstance(obj, list):
        # Assume the list is a list of potentially pinnable objects and traverse it.
        for idx, element in enumerate(obj):
            pinned = pin_object(element)
            if pinned is not element:
                obj[idx] = pinned
        return obj

    if isinstance(obj, dict):
        # Assume the values are pinnable. Re-assigning existing keys does not
        # change the dict size, so it is safe while iterating.
        for key, value in obj.items():
            pinned = pin_object(value)
            if pinned is not value:
                obj[key] = pinned
        return obj

    # unsupported type: nothing to pin
    return obj
56 changes: 56 additions & 0 deletions src/weathergen/datasets/stream_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,37 @@
from weathergen.common.io import IOReaderData


def _pin_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""Pin a tensor to CPU pinned memory.

Parameters
----------
tensor : torch.Tensor

Returns
-------
torch.Tensor
The pinned tensor.
"""
return tensor.pin_memory() if isinstance(tensor, torch.Tensor) else tensor


def _pin_tensor_list(tensor_list: list) -> list:
"""Pin all tensors in a list to CPU pinned memory.

Parameters
----------
tensor_list : list
List of tensors (or other objects) to pin.

Returns
-------
list
List with all torch.Tensor elements pinned to CPU pinned memory.
"""
return [_pin_tensor(t) for t in tensor_list]


class StreamData:
"""
StreamData object that encapsulates all data the model ingests for one batch item
Expand Down Expand Up @@ -75,6 +106,31 @@ def __init__(self, idx: int, input_steps: int, forecast_steps: int, healpix_cell
self.source_idxs_embed = [torch.tensor([]) for _ in range(self.input_steps)]
self.source_idxs_embed_pe = [torch.tensor([]) for _ in range(self.input_steps)]

def pin_memory(self):
    """Pin the tensor lists of this StreamData object and return the object.

    Each listed attribute is replaced by the list returned from
    ``_pin_tensor_list``. Entries of ``source_raw`` are asked to pin
    themselves only if they provide a ``pin_memory()`` method.
    """

    # target lists first, then source lists (processed in this order)
    tensor_list_fields = (
        "target_coords",
        "target_coords_lens",
        "target_tokens",
        "target_tokens_lens",
        "idxs_inv",
        "target_coords_raw",
        "source_tokens_cells",
        "source_tokens_lens",
        "source_idxs_embed",
        "source_idxs_embed_pe",
    )
    for field_name in tensor_list_fields:
        setattr(self, field_name, _pin_tensor_list(getattr(self, field_name)))

    # source_raw (list of IOReaderData objects) is optional
    if not hasattr(self, "source_raw"):
        return self

    for reader_data in self.source_raw:
        if reader_data is None or not hasattr(reader_data, "pin_memory"):
            continue
        reader_data.pin_memory()

    return self

def to_device(self, device: str) -> None:
"""
Move data to GPU
Expand Down
10 changes: 8 additions & 2 deletions src/weathergen/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ def inference(self, cf, devices, run_id_contd, mini_epoch_contd):
"batch_sampler": None,
"shuffle": False,
"num_workers": loader_num_workers,
"pin_memory": True,
}
self.data_loader_validation = torch.utils.data.DataLoader(
self.dataset, **loader_params, sampler=None
Expand Down Expand Up @@ -226,7 +225,6 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None):
"batch_sampler": None,
"shuffle": False,
"num_workers": cf.data_loading.num_workers,
"pin_memory": True,
}
self.data_loader = torch.utils.data.DataLoader(self.dataset, **loader_params, sampler=None)
self.data_loader_validation = torch.utils.data.DataLoader(
Expand Down Expand Up @@ -398,6 +396,10 @@ def train(self, mini_epoch):
# training loop
self.t_start = time.time()
for bidx, batch in enumerate(dataset_iter):
if cf.data_loading.get("memory_pinning", False):
# pin memory for faster CPU-GPU transfer
batch = batch.pin_memory()

batch.to_device(self.device)

with torch.autocast(
Expand Down Expand Up @@ -512,6 +514,10 @@ def validate(self, mini_epoch, mode_cfg, batch_size):
# print progress bar but only in interactive mode, i.e. when without ddp
with tqdm.tqdm(total=mode_cfg.samples_per_mini_epoch, disable=self.cf.with_ddp) as pbar:
for bidx, batch in enumerate(dataset_val_iter):
if cf.data_loading.get("memory_pinning", False):
# pin memory for faster CPU-GPU transfer
batch = batch.pin_memory()

batch.to_device(self.device)

# evaluate model
Expand Down
4 changes: 2 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading