
Commit c65cc70

Sidharth1743, pre-commit-ci[bot], and jacobbieker authored
feat: NVIDIA-style Hierarchical Gradient Checkpointing (#193)
* Added gradient checkpointing
* Docstrings matched, pytest modifications, benchmark merged
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Move graphcast to models/graphcast/model.py
* Remove redundant checkpoint flags test: removed the unused test for checkpoint flags.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jacob Prince-Bieker <jacob@bieker.tech>
1 parent 3f5c2e4 commit c65cc70

File tree

10 files changed: +1204 −170 lines changed


graph_weather/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@
     WrapperImageModel,
     WrapperMetaModel,
 )
+from .graphcast import GraphCast, GraphCastConfig
 from .layers.assimilator_decoder import AssimilatorDecoder
 from .layers.assimilator_encoder import AssimilatorEncoder
 from .layers.decoder import Decoder

graph_weather/models/graphcast/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
"""GraphCast model with gradient checkpointing."""

from .model import GraphCast, GraphCastConfig

__all__ = ["GraphCast", "GraphCastConfig"]
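
As a minimal sketch (assuming no other package changes), either import path now resolves to the same classes after the graph_weather/models/__init__.py change above:

from graph_weather.models import GraphCast, GraphCastConfig
# or directly from the new subpackage:
from graph_weather.models.graphcast import GraphCast, GraphCastConfig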

graph_weather/models/graphcast/model.py

Lines changed: 345 additions & 0 deletions

@@ -0,0 +1,345 @@
"""GraphCast model with hierarchical gradient checkpointing.

This module provides a complete GraphCast-style weather forecasting model
with NVIDIA-style hierarchical gradient checkpointing for memory-efficient training.

Based on:
- NVIDIA PhysicsNeMo GraphCast implementation
"""

from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint

from graph_weather.models.layers.decoder import Decoder
from graph_weather.models.layers.encoder import Encoder
from graph_weather.models.layers.processor import Processor


class GraphCast(torch.nn.Module):
    """GraphCast model with hierarchical gradient checkpointing.

    This model combines Encoder, Processor, and Decoder with NVIDIA-style
    hierarchical checkpointing controls for flexible memory-compute tradeoffs.

    Hierarchical checkpointing methods:
    - set_checkpoint_model(flag): Checkpoint entire forward pass
    - set_checkpoint_encoder(flag): Checkpoint encoder section
    - set_checkpoint_processor(segments): Checkpoint processor with configurable segments
    - set_checkpoint_decoder(flag): Checkpoint decoder section
    """

    def __init__(
        self,
        lat_lons: list,
        resolution: int = 2,
        input_dim: int = 78,
        output_dim: int = 78,
        hidden_dim: int = 256,
        num_processor_blocks: int = 9,
        hidden_layers: int = 2,
        mlp_norm_type: str = "LayerNorm",
        use_checkpointing: bool = False,
        efficient_batching: bool = False,
    ):
        """
        Initialize GraphCast model with hierarchical checkpointing support.

        Args:
            lat_lons: List of (lat, lon) tuples defining the grid points
            resolution: H3 resolution level
            input_dim: Input feature dimension
            output_dim: Output feature dimension
            hidden_dim: Hidden dimension for all layers
            num_processor_blocks: Number of message passing blocks in processor
            hidden_layers: Number of hidden layers in MLPs
            mlp_norm_type: Normalization type for MLPs
            use_checkpointing: Enable fine-grained checkpointing in all layers
            efficient_batching: Use efficient batching (avoid graph replication)
        """
        super().__init__()

        self.lat_lons = lat_lons
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.efficient_batching = efficient_batching

        # Initialize components
        self.encoder = Encoder(
            lat_lons=lat_lons,
            resolution=resolution,
            input_dim=input_dim,
            output_dim=hidden_dim,
            output_edge_dim=hidden_dim,
            hidden_dim_processor_node=hidden_dim,
            hidden_dim_processor_edge=hidden_dim,
            hidden_layers_processor_node=hidden_layers,
            hidden_layers_processor_edge=hidden_layers,
            mlp_norm_type=mlp_norm_type,
            use_checkpointing=use_checkpointing,
            efficient_batching=efficient_batching,
        )

        self.processor = Processor(
            input_dim=hidden_dim,
            edge_dim=hidden_dim,
            num_blocks=num_processor_blocks,
            hidden_dim_processor_node=hidden_dim,
            hidden_dim_processor_edge=hidden_dim,
            hidden_layers_processor_node=hidden_layers,
            hidden_layers_processor_edge=hidden_layers,
            mlp_norm_type=mlp_norm_type,
            use_checkpointing=use_checkpointing,
        )

        self.decoder = Decoder(
            lat_lons=lat_lons,
            resolution=resolution,
            input_dim=hidden_dim,
            output_dim=output_dim,
            hidden_dim_processor_node=hidden_dim,
            hidden_dim_processor_edge=hidden_dim,
            hidden_layers_processor_node=hidden_layers,
            hidden_layers_processor_edge=hidden_layers,
            mlp_norm_type=mlp_norm_type,
            hidden_dim_decoder=hidden_dim,
            hidden_layers_decoder=hidden_layers,
            use_checkpointing=use_checkpointing,
            efficient_batching=efficient_batching,
        )

        # Hierarchical checkpointing flags (default: use fine-grained checkpointing)
        self._checkpoint_model = False
        self._checkpoint_encoder = False
        self._checkpoint_processor_segments = 0  # 0 = use layer's internal checkpointing
        self._checkpoint_decoder = False

    def set_checkpoint_model(self, checkpoint_flag: bool):
        """
        Checkpoint entire model as a single segment.

        When enabled, creates one checkpoint for the entire forward pass.
        This provides maximum memory savings but highest recomputation cost.
        Disables all other hierarchical checkpointing when enabled.

        Args:
            checkpoint_flag: If True, checkpoint entire model. If False, use hierarchical checkpointing.
        """
        self._checkpoint_model = checkpoint_flag
        if checkpoint_flag:
            # Disable all fine-grained checkpointing
            self._checkpoint_encoder = False
            self._checkpoint_processor_segments = 0
            self._checkpoint_decoder = False

    def set_checkpoint_encoder(self, checkpoint_flag: bool):
        """
        Checkpoint encoder section.

        Checkpoints the encoder forward pass as a single segment.
        Only effective when set_checkpoint_model(False).

        Args:
            checkpoint_flag: If True, checkpoint encoder section.
        """
        self._checkpoint_encoder = checkpoint_flag

    def set_checkpoint_processor(self, checkpoint_segments: int):
        """
        Checkpoint processor with configurable segments.

        Controls how the processor is checkpointed:
        - 0: Use processor's internal per-block checkpointing
        - -1: Checkpoint entire processor as one segment
        - N > 0: Checkpoint every N blocks (not yet implemented)

        Only effective when set_checkpoint_model(False).

        Args:
            checkpoint_segments: Checkpointing strategy (0, -1, or positive integer).
        """
        self._checkpoint_processor_segments = checkpoint_segments

    def set_checkpoint_decoder(self, checkpoint_flag: bool):
        """
        Checkpoint decoder section.

        Checkpoints the decoder forward pass as a single segment.
        Only effective when set_checkpoint_model(False).

        Args:
            checkpoint_flag: If True, checkpoint decoder section.
        """
        self._checkpoint_decoder = checkpoint_flag

    def _encoder_forward(self, features: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Encoder forward pass (for checkpointing).
        """
        return self.encoder(features)

    def _processor_forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: Tensor,
        batch_size: Optional[int] = None,
    ) -> Tensor:
        """
        Processor forward pass (for checkpointing).
        """
        return self.processor(
            x,
            edge_index,
            edge_attr,
            batch_size=batch_size,
            efficient_batching=self.efficient_batching,
        )

    def _decoder_forward(
        self,
        processed_features: Tensor,
        original_features: Tensor,
        batch_size: int,
    ) -> Tensor:
        """
        Decoder forward pass (for checkpointing).
        """
        return self.decoder(processed_features, original_features, batch_size)

    def _custom_forward(self, features: Tensor) -> Tensor:
        """
        Forward pass with hierarchical checkpointing.
        """
        batch_size = features.shape[0]

        # Encoder
        if self._checkpoint_encoder:
            latent_features, edge_index, edge_attr = checkpoint(
                self._encoder_forward,
                features,
                use_reentrant=False,
                preserve_rng_state=False,
            )
        else:
            latent_features, edge_index, edge_attr = self.encoder(features)

        # Processor
        if self._checkpoint_processor_segments == -1:
            # Checkpoint entire processor as one block
            processed_features = checkpoint(
                self._processor_forward,
                latent_features,
                edge_index,
                edge_attr,
                batch_size if self.efficient_batching else None,
                use_reentrant=False,
                preserve_rng_state=False,
            )
        else:
            # Use processor's internal checkpointing (controlled by use_checkpointing)
            processed_features = self.processor(
                latent_features,
                edge_index,
                edge_attr,
                batch_size=batch_size,
                efficient_batching=self.efficient_batching,
            )

        # Decoder
        if self._checkpoint_decoder:
            output = checkpoint(
                self._decoder_forward,
                processed_features,
                features,
                batch_size,
                use_reentrant=False,
                preserve_rng_state=False,
            )
        else:
            output = self.decoder(processed_features, features, batch_size)

        return output

    def forward(self, features: Tensor) -> Tensor:
        """Forward pass through GraphCast model.

        Args:
            features: Input features of shape [batch_size, num_points, input_dim]

        Returns:
            Output predictions of shape [batch_size, num_points, output_dim]
        """
        if self._checkpoint_model:
            # Checkpoint entire model as one segment
            return checkpoint(
                self._custom_forward,
                features,
                use_reentrant=False,
                preserve_rng_state=False,
            )
        else:
            # Use hierarchical checkpointing
            return self._custom_forward(features)


class GraphCastConfig:
    """Configuration helper for GraphCast checkpointing strategies.

    Provides pre-defined checkpointing strategies for different use cases.
    """

    @staticmethod
    def no_checkpointing(model: GraphCast):
        """
        Disable all checkpointing (maximum speed, maximum memory).
        """
        model.set_checkpoint_model(False)
        model.set_checkpoint_encoder(False)
        model.set_checkpoint_processor(0)
        model.set_checkpoint_decoder(False)

    @staticmethod
    def full_checkpointing(model: GraphCast):
        """
        Checkpoint entire model (maximum memory savings, slowest).
        """
        model.set_checkpoint_model(True)

    @staticmethod
    def balanced_checkpointing(model: GraphCast):
        """
        Balanced strategy (good memory savings, moderate speed).
        """
        model.set_checkpoint_model(False)
        model.set_checkpoint_encoder(True)
        model.set_checkpoint_processor(-1)
        model.set_checkpoint_decoder(True)

    @staticmethod
    def processor_only_checkpointing(model: GraphCast):
        """
        Checkpoint only processor (targets main memory bottleneck).
        """
        model.set_checkpoint_model(False)
        model.set_checkpoint_encoder(False)
        model.set_checkpoint_processor(-1)
        model.set_checkpoint_decoder(False)

    @staticmethod
    def fine_grained_checkpointing(model: GraphCast):
        """
        Fine-grained per-layer checkpointing (best memory savings).

        This checkpoints each individual MLP and processor block separately.
        Provides the best memory savings with moderate recomputation cost.
        Note: Model must be created with use_checkpointing=True.
        """
        # Fine-grained is enabled via use_checkpointing=True in __init__
        # This just disables hierarchical checkpointing
        model.set_checkpoint_model(False)
        model.set_checkpoint_encoder(False)
        model.set_checkpoint_processor(0)
        model.set_checkpoint_decoder(False)
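
For orientation, here is a hedged usage sketch of the hierarchical checkpointing API added above. The grid, tensor shapes, and training step are illustrative only, and it assumes the encoder/decoder graph dependencies (e.g. the h3 package) are installed.

import torch

from graph_weather.models import GraphCast, GraphCastConfig

# Illustrative coarse grid; real grids are far denser.
lat_lons = [(lat, lon) for lat in range(-90, 90, 10) for lon in range(0, 360, 10)]

model = GraphCast(lat_lons=lat_lons, input_dim=78, output_dim=78, hidden_dim=256)

# Pick one pre-defined strategy; "balanced" checkpoints encoder, processor, and decoder
# as three separate segments (equivalent to the manual set_checkpoint_* calls).
GraphCastConfig.balanced_checkpointing(model)

features = torch.randn(1, len(lat_lons), 78)  # [batch_size, num_points, input_dim]
out = model(features)  # [batch_size, num_points, output_dim]
loss = out.square().mean()
loss.backward()  # checkpointed segments are recomputed here, trading compute for activation memory

Swapping the strategy is a one-line change, e.g. GraphCastConfig.processor_only_checkpointing(model) to target only the deep message-passing stack.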

graph_weather/models/layers/assimilator_decoder.py

Lines changed: 1 addition & 0 deletions

@@ -117,6 +117,7 @@ def __init__(
             hidden_layers_node=hidden_layers_processor_node,
             hidden_layers_edge=hidden_layers_processor_edge,
             norm_type=mlp_norm_type,
+            use_checkpointing=self.use_checkpointing,
         )
         self.node_decoder = MLP(
             input_dim,

graph_weather/models/layers/encoder.py

Lines changed: 1 addition & 0 deletions

@@ -147,6 +147,7 @@ def __init__(
             hidden_layers_processor_node,
             hidden_layers_processor_edge,
             mlp_norm_type,
+            use_checkpointing=self.use_checkpointing,
         )

     def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
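
The two small diffs above thread use_checkpointing down into the layers' internal graph processors. The layer code itself is not part of this commit's diff, so the following is only a hedged sketch of the pattern such a flag typically enables inside a sub-module; CheckpointedMLP and its members are illustrative names, not this repository's API.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedMLP(nn.Module):
    """Toy MLP that recomputes its forward pass during backward when the flag is set."""

    def __init__(self, dim: int, use_checkpointing: bool = False):
        super().__init__()
        self.use_checkpointing = use_checkpointing
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_checkpointing and self.training:
            # Intermediate activations inside self.net are freed after the forward pass
            # and recomputed during backward, cutting peak memory per block.
            return checkpoint(self.net, x, use_reentrant=False, preserve_rng_state=False)
        return self.net(x)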
