Skip to content

Commit 5bc68d7

Browse files
committed
introduce new TriNodesEdgebuilder-strategy
1 parent be9ff8b commit 5bc68d7

File tree

3 files changed

+62
-14
lines changed

3 files changed

+62
-14
lines changed

graphs/src/anemoi/graphs/edges/builders/multi_scale.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class MultiScaleEdges(BaseEdgeBuilder):
3838
Defines the refinement levels at which edges are computed. If an integer is provided, edges are computed for all
3939
levels up to and including that level. For instance, `scale_resolutions=4` includes edges at levels 1 through 4,
4040
whereas `scale_resolutions=[4]` only includes edges at level 4.
41+
new_method: bool, optional
42+
Use the new edge building strategy for x_hops=1. Default is False
4143
4244
Methods
4345
-------
@@ -59,6 +61,7 @@ def __init__(
5961
target_name: str,
6062
x_hops: int,
6163
scale_resolutions: int | list[int] | None = None,
64+
new_method: bool = False,
6265
**kwargs,
6366
):
6467
super().__init__(source_name, target_name)
@@ -74,6 +77,7 @@ def __init__(
7477
scale_resolutions is None or min(scale_resolutions) > 0
7578
), "The scale_resolutions argument only supports positive integers."
7679
self.scale_resolutions = scale_resolutions
80+
self.new_method = new_method
7781

7882
@staticmethod
7983
def get_edge_builder_class(node_type: str) -> type[BaseIcosahedronEdgeStrategy]:
@@ -102,10 +106,14 @@ def compute_edge_index(self, source_nodes: NodeStorage, _target_nodes: NodeStora
102106
)
103107

104108
# Add edges
105-
source_nodes = edge_builder_cls().add_edges(source_nodes, self.x_hops, scale_resolutions=scale_resolutions)
106-
adjmat = nx.to_scipy_sparse_array(source_nodes["_nx_graph"], format="coo")
107-
108-
# Get source & target indices of the edges
109-
edge_index = np.stack([adjmat.col, adjmat.row], axis=0)
109+
source_nodes = edge_builder_cls().add_edges(source_nodes, self.x_hops, scale_resolutions=scale_resolutions, new_method=self.new_method)
110+
if self.new_method:
111+
# If the new method is used, the edges are already computed and stored in the node storage
112+
edge_index = source_nodes["_multiscale_edges"]
113+
else:
114+
adjmat = nx.to_scipy_sparse_array(source_nodes["_nx_graph"], format="coo")
115+
# Get source & target indices of the edges
116+
edge_index = np.stack([adjmat.col, adjmat.row], axis=0)
117+
110118

111119
return torch.from_numpy(edge_index).to(torch.int32)

graphs/src/anemoi/graphs/generate/multi_scale_edges.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from abc import abstractmethod
1313

1414
from torch_geometric.data.storage import NodeStorage
15-
15+
import logging
16+
LOGGER = logging.getLogger(__name__)
1617

1718
class BaseIcosahedronEdgeStrategy(ABC):
1819
"""Abstract base class for different edge-building strategies."""
@@ -24,16 +25,28 @@ def add_edges(self, nodes: NodeStorage, x_hops: int, scale_resolutions: list[int
2425
class TriNodesEdgeBuilder(BaseIcosahedronEdgeStrategy):
2526
"""Edge builder for TriNodes and LimitedAreaTriNodes."""
2627

27-
def add_edges(self, nodes: NodeStorage, x_hops: int, scale_resolutions: list[int]) -> NodeStorage:
28+
def add_edges(self, nodes: NodeStorage, x_hops: int, scale_resolutions: list[int], new_method: bool = False ) -> NodeStorage:
2829
from anemoi.graphs.generate import tri_icosahedron
2930

30-
nodes["_nx_graph"] = tri_icosahedron.add_edges_to_nx_graph(
31-
nodes["_nx_graph"],
32-
resolutions=scale_resolutions,
33-
x_hops=x_hops,
34-
area_mask_builder=nodes.get("_area_mask_builder", None),
35-
)
36-
return nodes
31+
if new_method:
32+
assert x_hops == 1, "New strategy currently only supports x_hops=1."
33+
LOGGER.info("Using new strategy for x_hops=1 multiscale-edge building.")
34+
# Compute the multiscale edges directly and store them in the node storage
35+
multiscale_edges = tri_icosahedron.add_edges_hop_1(
36+
nodes_coords_rad = nodes["x"],
37+
resolutions=scale_resolutions,
38+
node_ordering=nodes["_node_ordering"],
39+
)
40+
nodes["_multiscale_edges"] = multiscale_edges
41+
else:
42+
LOGGER.info("Using existing strategy for multiscale-edge building.")
43+
nodes["_nx_graph"] = tri_icosahedron.add_edges_to_nx_graph(
44+
nodes["_nx_graph"],
45+
resolutions=scale_resolutions,
46+
x_hops=x_hops,
47+
area_mask_builder=nodes.get("_area_mask_builder", None),
48+
)
49+
return nodes
3750

3851

3952
class HexNodesEdgeBuilder(BaseIcosahedronEdgeStrategy):

graphs/src/anemoi/graphs/generate/tri_icosahedron.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
import numpy as np
1515
import trimesh
1616
from sklearn.neighbors import BallTree
17+
import logging
1718

1819
from anemoi.graphs.generate.masks import KNNAreaMaskBuilder
1920
from anemoi.graphs.generate.transforms import cartesian_to_latlon_rad
2021
from anemoi.graphs.generate.utils import get_coordinates_ordering
2122

23+
LOGGER = logging.getLogger(__name__)
2224

2325
def create_tri_nodes(
2426
resolution: int, area_mask_builder: KNNAreaMaskBuilder | None = None
@@ -133,6 +135,29 @@ def create_nx_graph_from_tri_coords(coords_rad: np.ndarray, node_ordering: np.nd
133135
assert graph.number_of_nodes() == len(node_ordering), "The number of nodes must be the same."
134136
return graph
135137

138+
def add_edges_hop_1(nodes_coords_rad, resolutions: list[int], node_ordering: list[int]) -> np.ndarray:
139+
"""Adds edges for x_hops = 1 relying on trimesh only."""
140+
141+
hop_1_edges = []
142+
for subdivisions in resolutions:
143+
sphere = trimesh.creation.icosphere(subdivisions=subdivisions, radius=1.0)
144+
coords_rad = cartesian_to_latlon_rad(sphere.vertices)
145+
LOGGER.debug("Adding %d 1-hop edges for resolution %d", sphere.edges.shape[0], subdivisions)
146+
hop_1_edges.append(sphere.edges)
147+
# Check that the node coordinates from the node storage match those were generated
148+
diff = np.sum(np.abs(nodes_coords_rad.numpy()[:coords_rad.shape[0],:] - coords_rad[node_ordering]))
149+
assert diff == 0, "Node coordinates do not match coordinates generated by trimesh."
150+
151+
# Concatenate all edges from different resolutions and transpose to get shape (2, num_edges)
152+
multiscale_edges = np.transpose(np.concatenate(hop_1_edges, axis=0), (1,0))
153+
154+
# Map the edges to the node ordering
155+
inverse_ordering = np.empty_like(node_ordering)
156+
inverse_ordering[node_ordering] = np.arange(len(node_ordering))
157+
multiscale_edges = inverse_ordering[multiscale_edges]
158+
159+
LOGGER.debug("multiscale_edges_shape: %s", multiscale_edges.shape)
160+
return multiscale_edges
136161

137162
def add_edges_to_nx_graph(
138163
graph: nx.DiGraph,
@@ -183,6 +208,8 @@ def add_edges_to_nx_graph(
183208

184209
_, vertex_mapping_index = tree.query(r_vertices_rad, k=1)
185210
neighbour_pairs = create_node_neighbours_list(graph, node_neighbours, vertex_mapping_index)
211+
LOGGER.debug("Adding %d edges for resolution %d", len(neighbour_pairs), resolution)
212+
#LOGGER.debug("Sample edges: %s", neighbour_pairs)
186213
graph.add_edges_from(neighbour_pairs)
187214
return graph
188215

0 commit comments

Comments
 (0)