Skip to content

Commit c5b1864

Browse files
Optimization: Efficient Sequential Batching (Solves #47, Supersedes #186) (#187)
* feat: Add efficient batching to reduce memory usage * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: Add efficient batching to reduce memory usage * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Dependencies resolved, unnecessary docs removed, moved to pytest, benchmark is under scripts * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: Add efficient batching to reduce memory usage * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert to nnja_ai import * fix: Revert nnja_ai import in data module * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8ae257c commit c5b1864

File tree

6 files changed

+485
-67
lines changed

6 files changed

+485
-67
lines changed

graph_weather/models/layers/assimilator_decoder.py

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __init__(
4141
hidden_dim_decoder: int = 128,
4242
hidden_layers_decoder: int = 2,
4343
use_checkpointing: bool = False,
44+
efficient_batching: bool = False,
4445
):
4546
"""
4647
Decoder from latent graph to lat/lon graph for assimilation of observation
@@ -63,6 +64,7 @@ def __init__(
6364
"""
6465
super().__init__()
6566
self.use_checkpointing = use_checkpointing
67+
self.efficient_batching = efficient_batching
6668
self.num_latlons = len(lat_lons)
6769
self.base_h3_grid = sorted(list(h3.uncompact_cells(h3.get_res0_cells(), resolution)))
6870
self.num_h3 = len(self.base_h3_grid)
@@ -137,28 +139,61 @@ def forward(self, processor_features: torch.Tensor, batch_size: int) -> torch.Te
137139
Updated features for model
138140
"""
139141
self.graph = self.graph.to(processor_features.device)
140-
edge_attr = self.edge_encoder(self.graph.edge_attr) # Update attributes based on distance
141-
edge_attr = einops.repeat(edge_attr, "e f -> (repeat e) f", repeat=batch_size)
142-
143-
edge_index = torch.cat(
144-
[
145-
self.graph.edge_index + i * torch.max(self.graph.edge_index) + i
146-
for i in range(batch_size)
147-
],
148-
dim=1,
149-
)
150-
151-
# Readd nodes to match graph node number
152142
self.latlon_nodes = self.latlon_nodes.to(processor_features.device)
153-
features = einops.rearrange(processor_features, "(b n) f -> b n f", b=batch_size)
154-
features = torch.cat(
155-
[features, einops.repeat(self.latlon_nodes, "n f -> b n f", b=batch_size)], dim=1
156-
)
157-
features = einops.rearrange(features, "b n f -> (b n) f")
158-
159-
out, _ = self.graph_processor(features, edge_index, edge_attr) # Message Passing
160-
# Remove the h3 nodes now, only want the latlon ones
161-
out = self.node_decoder(out) # Decode to 78 from 256
162-
out = einops.rearrange(out, "(b n) f -> b n f", b=batch_size)
163-
test, out = torch.split(out, [self.num_h3, self.num_latlons], dim=1)
164-
return out
143+
144+
if self.efficient_batching:
145+
# Efficient batching: process batches separately to avoid graph replication
146+
edge_attr = self.edge_encoder(self.graph.edge_attr) # Encode once
147+
148+
# Split processor features by batch
149+
proc_features_batched = einops.rearrange(
150+
processor_features, "(b n) f -> b n f", b=batch_size
151+
)
152+
153+
batch_outputs = []
154+
for i in range(batch_size):
155+
# Get features for this batch
156+
feat_i = torch.cat(
157+
[proc_features_batched[i], self.latlon_nodes], dim=0
158+
) # [num_h3 + num_latlon, F]
159+
160+
# Message passing with single graph (no replication)
161+
out_i, _ = self.graph_processor(feat_i, self.graph.edge_index, edge_attr)
162+
163+
# Decode and extract latlon nodes
164+
out_i = self.node_decoder(out_i)
165+
out_i = out_i[self.num_h3 :] # Keep only latlon nodes
166+
167+
batch_outputs.append(out_i)
168+
169+
# Stack outputs
170+
out = torch.stack(batch_outputs, dim=0) # [B, num_latlon, F]
171+
return out
172+
else:
173+
# Original batching implementation
174+
edge_attr = self.edge_encoder(
175+
self.graph.edge_attr
176+
) # Update attributes based on distance
177+
edge_attr = einops.repeat(edge_attr, "e f -> (repeat e) f", repeat=batch_size)
178+
179+
edge_index = torch.cat(
180+
[
181+
self.graph.edge_index + i * torch.max(self.graph.edge_index) + i
182+
for i in range(batch_size)
183+
],
184+
dim=1,
185+
)
186+
187+
# Readd nodes to match graph node number
188+
features = einops.rearrange(processor_features, "(b n) f -> b n f", b=batch_size)
189+
features = torch.cat(
190+
[features, einops.repeat(self.latlon_nodes, "n f -> b n f", b=batch_size)], dim=1
191+
)
192+
features = einops.rearrange(features, "b n f -> (b n) f")
193+
194+
out, _ = self.graph_processor(features, edge_index, edge_attr) # Message Passing
195+
# Remove the h3 nodes now, only want the latlon ones
196+
out = self.node_decoder(out) # Decode to 78 from 256
197+
out = einops.rearrange(out, "(b n) f -> b n f", b=batch_size)
198+
test, out = torch.split(out, [self.num_h3, self.num_latlons], dim=1)
199+
return out

graph_weather/models/layers/decoder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
hidden_dim_decoder: int = 128,
3838
hidden_layers_decoder: int = 2,
3939
use_checkpointing: bool = False,
40+
efficient_batching: bool = False,
4041
):
4142
"""
4243
Decoder from latent graph to lat/lon graph
@@ -56,6 +57,7 @@ def __init__(
5657
mlp_norm_type: Type of norm for the MLPs
5758
one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
5859
use_checkpointing: Whether to use gradient checkpointing or not
60+
efficient_batching: Whether to use efficient batching (avoids graph replication)
5961
"""
6062
super().__init__(
6163
lat_lons,
@@ -71,6 +73,7 @@ def __init__(
7173
hidden_dim_decoder,
7274
hidden_layers_decoder,
7375
use_checkpointing,
76+
efficient_batching,
7477
)
7578

7679
def forward(

graph_weather/models/layers/encoder.py

Lines changed: 74 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
hidden_layers_processor_edge=2,
5050
mlp_norm_type="LayerNorm",
5151
use_checkpointing: bool = False,
52+
efficient_batching: bool = False,
5253
):
5354
"""
5455
Encode the lat/lon data inot the isohedron graph
@@ -69,6 +70,7 @@ def __init__(
6970
"""
7071
super().__init__()
7172
self.use_checkpointing = use_checkpointing
73+
self.efficient_batching = efficient_batching
7274
self.output_dim = output_dim
7375
self.num_latlons = len(lat_lons)
7476
self.base_h3_grid = sorted(list(h3.uncompact_cells(h3.get_res0_cells(), resolution)))
@@ -161,46 +163,82 @@ def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, t
161163
self.h3_nodes = self.h3_nodes.to(features.device)
162164
self.graph = self.graph.to(features.device)
163165
self.latent_graph = self.latent_graph.to(features.device)
164-
features = torch.cat(
165-
[features, einops.repeat(self.h3_nodes, "n f -> b n f", b=batch_size)],
166-
dim=1,
167-
)
168-
# Cat with the h3 nodes to have correct amount of nodes, and in right order
169-
features = einops.rearrange(features, "b n f -> (b n) f")
170-
out = self.node_encoder(features) # Encode to 256 from 78
171-
edge_attr = self.edge_encoder(self.graph.edge_attr) # Update attributes based on distance
172-
# Copy attributes batch times
173-
edge_attr = einops.repeat(edge_attr, "e f -> (repeat e) f", repeat=batch_size)
174-
# Expand edge index correct number of times while adding the proper number to the edge index
175-
edge_index = torch.cat(
176-
[
177-
self.graph.edge_index + i * torch.max(self.graph.edge_index) + i
178-
for i in range(batch_size)
179-
],
180-
dim=1,
181-
)
182-
out, _ = self.graph_processor(out, edge_index, edge_attr) # Message Passing
183-
# Remove the extra nodes (lat/lon) from the output
184-
out = einops.rearrange(out, "(b n) f -> b n f", b=batch_size)
185-
_, out = torch.split(out, [self.num_latlons, self.h3_nodes.shape[0]], dim=1)
186-
out = einops.rearrange(out, "b n f -> (b n) f")
187-
return (
188-
out,
189-
torch.cat(
166+
167+
if self.efficient_batching:
168+
# Efficient batching: process batches separately to avoid graph replication
169+
batch_outputs = []
170+
for i in range(batch_size):
171+
# Process single batch item
172+
feat_i = torch.cat([features[i : i + 1], self.h3_nodes.unsqueeze(0)], dim=1)
173+
feat_i = feat_i.squeeze(0) # [N, F]
174+
175+
# Encode nodes
176+
out_i = self.node_encoder(feat_i)
177+
178+
# Encode edges (no replication needed)
179+
edge_attr_i = self.edge_encoder(self.graph.edge_attr)
180+
181+
# Message passing with single graph
182+
out_i, _ = self.graph_processor(out_i, self.graph.edge_index, edge_attr_i)
183+
184+
# Extract H3 nodes only
185+
out_i = out_i[self.num_latlons :] # Keep only H3 nodes
186+
batch_outputs.append(out_i)
187+
188+
# Stack outputs
189+
out = torch.cat(batch_outputs, dim=0) # [B*num_h3, F]
190+
191+
# Return with SHARED latent graph (NO replication at all!)
192+
latent_edge_attr = self.latent_edge_encoder(self.latent_graph.edge_attr)
193+
194+
# Return the single shared graph - no batching overhead
195+
return (out, self.latent_graph.edge_index, latent_edge_attr)
196+
else:
197+
# Original batching implementation
198+
features = torch.cat(
199+
[features, einops.repeat(self.h3_nodes, "n f -> b n f", b=batch_size)],
200+
dim=1,
201+
)
202+
# Cat with the h3 nodes to have correct amount of nodes, and in right order
203+
features = einops.rearrange(features, "b n f -> (b n) f")
204+
out = self.node_encoder(features) # Encode to 256 from 78
205+
edge_attr = self.edge_encoder(
206+
self.graph.edge_attr
207+
) # Update attributes based on distance
208+
# Copy attributes batch times
209+
edge_attr = einops.repeat(edge_attr, "e f -> (repeat e) f", repeat=batch_size)
210+
# Expand edge index correct number of times while adding the proper number to the edge index
211+
edge_index = torch.cat(
190212
[
191-
self.latent_graph.edge_index + i * torch.max(self.latent_graph.edge_index) + i
213+
self.graph.edge_index + i * torch.max(self.graph.edge_index) + i
192214
for i in range(batch_size)
193215
],
194216
dim=1,
195-
),
196-
self.latent_edge_encoder(
197-
einops.repeat(
198-
self.latent_graph.edge_attr,
199-
"e f -> (repeat e) f",
200-
repeat=batch_size,
201-
)
202-
),
203-
) # New graph
217+
)
218+
out, _ = self.graph_processor(out, edge_index, edge_attr) # Message Passing
219+
# Remove the extra nodes (lat/lon) from the output
220+
out = einops.rearrange(out, "(b n) f -> b n f", b=batch_size)
221+
_, out = torch.split(out, [self.num_latlons, self.h3_nodes.shape[0]], dim=1)
222+
out = einops.rearrange(out, "b n f -> (b n) f")
223+
return (
224+
out,
225+
torch.cat(
226+
[
227+
self.latent_graph.edge_index
228+
+ i * torch.max(self.latent_graph.edge_index)
229+
+ i
230+
for i in range(batch_size)
231+
],
232+
dim=1,
233+
),
234+
self.latent_edge_encoder(
235+
einops.repeat(
236+
self.latent_graph.edge_attr,
237+
"e f -> (repeat e) f",
238+
repeat=batch_size,
239+
)
240+
),
241+
) # New graph
204242

205243
def create_latent_graph(self) -> Data:
206244
"""

graph_weather/models/layers/processor.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,20 +63,49 @@ def __init__(
6363
if self.use_thermalizer:
6464
self.thermalizer = ThermalizerLayer(input_dim)
6565

66-
def forward(self, x: torch.Tensor, edge_index, edge_attr, t: int = 0) -> torch.Tensor:
66+
def forward(
67+
self,
68+
x: torch.Tensor,
69+
edge_index,
70+
edge_attr,
71+
t: int = 0,
72+
batch_size: int = None,
73+
efficient_batching: bool = False,
74+
) -> torch.Tensor:
6775
"""
6876
Adds features to the encoding graph
6977
7078
Args:
71-
x: Torch tensor containing node features
79+
x: Torch tensor containing node features [B*N, F] or [N, F]
7280
edge_index: Connectivity of graph, of shape [2, Num edges] in COO format
73-
edge_attr: Edge attribues in [Num edges, Features] shape
81+
edge_attr: Edge attributes in [Num edges, Features] shape
7482
t: Timestep for the thermalizer
83+
batch_size: Batch size (required when efficient_batching=True)
84+
efficient_batching: If True, process batches separately with shared graph
7585
7686
Returns:
7787
torch Tensor containing the values of the nodes of the graph
7888
"""
79-
out, _ = self.graph_processor(x, edge_index, edge_attr)
80-
if self.use_thermalizer:
81-
out = self.thermalizer(out, t)
82-
return out
89+
if efficient_batching and batch_size is not None and batch_size > 1:
90+
# Efficient batching: process each batch separately with shared graph
91+
# x is [B*N, F], split into B batches of [N, F]
92+
num_nodes_per_batch = x.shape[0] // batch_size
93+
x_batched = x.view(batch_size, num_nodes_per_batch, -1)
94+
95+
batch_outputs = []
96+
for i in range(batch_size):
97+
# Process single batch with shared graph
98+
out_i, _ = self.graph_processor(x_batched[i], edge_index, edge_attr)
99+
if self.use_thermalizer:
100+
out_i = self.thermalizer(out_i, t)
101+
batch_outputs.append(out_i)
102+
103+
# Concatenate outputs back to [B*N, F] format
104+
out = torch.cat(batch_outputs, dim=0)
105+
return out
106+
else:
107+
# Original batching: process all at once with batched graph
108+
out, _ = self.graph_processor(x, edge_index, edge_attr)
109+
if self.use_thermalizer:
110+
out = self.thermalizer(out, t)
111+
return out

0 commit comments

Comments
 (0)