Skip to content
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
970c4dc
add pin mem to IOReaderData
Dec 24, 2025
5c566df
add pin mem to sample & modelbatch class
Dec 24, 2025
e85309d
add pin mem to stream data
Dec 24, 2025
ac3b089
add pin mem to training loop
Dec 24, 2025
c3fc9a7
run /scripts/actions.sh lint
Dec 29, 2025
7ac3b3e
run ./scripts/actions.sh unit-test
Dec 29, 2025
a65f561
ignore check torch import in package
Dec 29, 2025
98f4e0b
move pinning to MultiStreamDataSampler
Dec 30, 2025
bc80b26
add _pin_tensor & _pin_tensor_list helper func
Dec 30, 2025
8f98482
ruff the code
Dec 30, 2025
ea8f16c
move back pin mem. to train loop
Dec 30, 2025
61433eb
Remove the ignore-import-error rule and revert to the state before th…
Dec 30, 2025
48c51e3
create protocol for pinnable obj
Dec 30, 2025
dc40a2f
remove pin_mem from IOReaderData class
Dec 30, 2025
36c4b9c
add pin_memory to Trainer.validate
Dec 30, 2025
ebec481
remove pin_memory from loader_params
Dec 30, 2025
62c4e02
Revert export/export_inference.py to state before c3fc9a78
Dec 30, 2025
6a22234
change name
Jan 6, 2026
3796bc8
revise Pinnable class description
Jan 6, 2026
e29160a
Merge branch 'ecmwf:develop' into javad/dev/manual-mem-pinning-1399
javak87 Jan 7, 2026
7fe5b44
add memory_pinning in config, train & val loop
Jan 13, 2026
20944f3
Merge branch 'develop' into javad/dev/manual-mem-pinning-1399
javak87 Jan 14, 2026
08078e8
use getattr to avoid CICD warning
Jan 14, 2026
bd57cf4
use setattr to avoid CICD warning
Jan 14, 2026
503d742
disable pylint for self.source_tokens_lens
Jan 14, 2026
71461b6
Merge branch 'develop' into javad/dev/manual-mem-pinning-1399
clessig Jan 14, 2026
a31d6ea
changes based on #1615
Jan 16, 2026
7a98a08
Merge branch 'javad/dev/manual-mem-pinning-1399' of https://github.co…
Jan 16, 2026
039121b
Merge branch 'develop' into javad/dev/manual-mem-pinning-1399
javak87 Jan 16, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion packages/common/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ missing-attribute = false
no-matching-overload = false
bad-context-manager = false


[tool.pyrefly.ignores]
import-error = [
"torch",
]


# The linting configuration
Expand Down
11 changes: 11 additions & 0 deletions packages/common/src/weathergen/common/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import dask.array as da
import numpy as np
import torch
import xarray as xr
import zarr
from numpy import datetime64
Expand Down Expand Up @@ -109,6 +110,16 @@ def is_empty(self):
"""
return len(self.data) == 0

def pin_memory(self):
    """Pin all tensors in IOReaderData.

    Each tensor-valued field is replaced by its pinned-memory copy;
    non-tensor or missing fields are left untouched.

    Returns:
        IOReaderData: self, so the call can be chained.
    """
    for field in ("coords", "data", "geoinfos"):
        value = getattr(self, field, None)
        if isinstance(value, torch.Tensor):
            setattr(self, field, value.pin_memory())
    return self

@classmethod
def create(cls, other: typing.Any) -> "IOReaderData":
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def parse_args(args: list) -> argparse.Namespace:
type=str,
help="Grid type to include in the output filename (i.e. 'O96/N320')",
required=False,
default="O96",
default="O96",
dest="quaver_template_grid_type",
)

Expand Down
36 changes: 36 additions & 0 deletions src/weathergen/datasets/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,25 @@ class Sample:
# keys: stream_name, values: StreamData
streams_data: dict[str, StreamData | None]

def pin_memory(self):
    """Pin all tensors in this Sample to CPU pinned memory.

    Delegates to each contained StreamData's own ``pin_memory`` and pins
    the mask tensor of every SampleMetaData entry in ``meta_info``.

    Returns:
        Sample: self, so the call can be chained.
    """
    # Delegate pinning to every non-None StreamData in streams_data.
    streams = getattr(self, "streams_data", None)
    if isinstance(streams, dict):
        for stream_data in streams.values():
            if stream_data is not None and hasattr(stream_data, "pin_memory"):
                stream_data.pin_memory()

    # Pin the mask tensor of each SampleMetaData entry.
    meta = getattr(self, "meta_info", None)
    if isinstance(meta, dict):
        for meta_data in meta.values():
            if isinstance(meta_data, SampleMetaData) and isinstance(
                meta_data.mask, torch.Tensor
            ):
                meta_data.mask = meta_data.mask.pin_memory()

    return self

def __init__(self, streams: dict) -> None:
# TODO: can we pass this right away?
self.meta_info = {}
Expand Down Expand Up @@ -124,6 +143,23 @@ def __init__(self, streams, num_source_samples: int, num_target_samples: int) ->
self.source2target_matching_idxs = np.full(num_source_samples, -1, dtype=np.int32)
self.target2source_matching_idxs = [[] for _ in range(num_target_samples)]

def pin_memory(self):
    """Pin all tensors in this batch to CPU pinned memory.

    Pins every source and target sample in place, then pins the
    ``source_tokens_lens`` tensor if present as a tensor.

    Returns:
        ModelBatch: self, so the call can be chained.
    """
    # Sources first, then targets — each Sample pins its own tensors.
    for group in (self.source_samples, self.target_samples):
        for sample in group:
            sample.pin_memory()

    # source_tokens_lens may be a plain sequence; only pin real tensors.
    lens = self.source_tokens_lens
    if isinstance(lens, torch.Tensor):
        self.source_tokens_lens = lens.pin_memory()

    return self

def to_device(self, device): # -> ModelBatch
for sample in self.source_samples:
sample.to_device(device)
Expand Down
3 changes: 3 additions & 0 deletions src/weathergen/datasets/multi_stream_data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,9 @@ def __iter__(self) -> ModelBatch:

batch = self._get_batch(idx, forecast_dt)

# pin memory for faster CPU-GPU transfer
batch = batch.pin_memory()

# skip completely empty batch item or when all targets are empty -> no grad
if not batch.is_empty():
break
Expand Down
57 changes: 57 additions & 0 deletions src/weathergen/datasets/stream_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,63 @@ def __init__(self, idx: int, input_steps: int, forecast_steps: int, healpix_cell
self.source_idxs_embed = [torch.tensor([]) for _ in range(self.input_steps)]
self.source_idxs_embed_pe = [torch.tensor([]) for _ in range(self.input_steps)]

def pin_memory(self):
    """Pin all tensors in this StreamData object to CPU pinned memory.

    Pinning (page-locking) host tensors enables faster, asynchronous
    CPU-to-GPU transfers later in the training loop. Non-tensor entries
    in each list are passed through unchanged. Most fields skip empty
    tensors (nothing to transfer); ``source_tokens_cells`` historically
    pinned tensors regardless of size, and that behavior is preserved.

    Returns:
        StreamData: self, so the call can be chained.
    """

    def _pin_nonempty(items):
        # Pin only real, non-empty tensors; everything else passes through.
        return [
            t.pin_memory() if isinstance(t, torch.Tensor) and t.numel() > 0 else t
            for t in items
        ]

    def _pin_any(items):
        # Pin every tensor, including empty ones; non-tensors pass through.
        return [t.pin_memory() if isinstance(t, torch.Tensor) else t for t in items]

    # Target-side tensors.
    self.target_coords = _pin_nonempty(self.target_coords)
    self.target_coords_lens = _pin_nonempty(self.target_coords_lens)
    self.target_tokens = _pin_nonempty(self.target_tokens)
    self.target_tokens_lens = _pin_nonempty(self.target_tokens_lens)
    self.idxs_inv = _pin_nonempty(self.idxs_inv)
    self.target_coords_raw = _pin_nonempty(self.target_coords_raw)

    # Source-side tensors.
    self.source_tokens_cells = _pin_any(self.source_tokens_cells)
    self.source_tokens_lens = _pin_nonempty(self.source_tokens_lens)
    self.source_idxs_embed = _pin_nonempty(self.source_idxs_embed)
    self.source_idxs_embed_pe = _pin_nonempty(self.source_idxs_embed_pe)

    # source_raw holds IOReaderData objects that know how to pin themselves.
    if hasattr(self, "source_raw"):
        for raw_data in self.source_raw:
            if raw_data is not None and hasattr(raw_data, "pin_memory"):
                raw_data.pin_memory()

    return self

def to_device(self, device: str) -> None:
"""
Move data to GPU
Expand Down
20 changes: 18 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading