Skip to content

Commit 87d435e

Browse files
committed
support for LimitedAreaTriNodes
1 parent c90e5db commit 87d435e

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

graphs/src/anemoi/graphs/generate/multi_scale_edges.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def add_edges(
4040
nodes_coords_rad=nodes["x"],
4141
resolutions=scale_resolutions,
4242
node_ordering=nodes["_node_ordering"],
43+
area_mask_builder=nodes.get("_area_mask_builder", None),
4344
)
4445
nodes["_multiscale_edges"] = multiscale_edges
4546
else:
@@ -50,7 +51,7 @@ def add_edges(
5051
x_hops=x_hops,
5152
area_mask_builder=nodes.get("_area_mask_builder", None),
5253
)
53-
return nodes
54+
return nodes
5455

5556

5657
class HexNodesEdgeBuilder(BaseIcosahedronEdgeStrategy):

graphs/src/anemoi/graphs/generate/tri_icosahedron.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,12 @@ def create_nx_graph_from_tri_coords(coords_rad: np.ndarray, node_ordering: np.nd
137137
return graph
138138

139139

140-
def add_edges_hop_1(nodes_coords_rad, resolutions: list[int], node_ordering: list[int]) -> np.ndarray:
140+
def add_edges_hop_1(
141+
nodes_coords_rad,
142+
resolutions: list[int],
143+
node_ordering: list[int],
144+
area_mask_builder: KNNAreaMaskBuilder | None = None,
145+
) -> np.ndarray:
141146
"""Adds edges for x_hops = 1 relying on trimesh only."""
142147

143148
hop_1_edges = []
@@ -154,9 +159,16 @@ def add_edges_hop_1(nodes_coords_rad, resolutions: list[int], node_ordering: lis
154159
multiscale_edges = np.transpose(np.concatenate(hop_1_edges, axis=0), (1, 0))
155160

156161
# Map the edges to the node ordering
157-
inverse_ordering = np.empty_like(node_ordering)
158-
inverse_ordering[node_ordering] = np.arange(len(node_ordering))
159-
multiscale_edges = inverse_ordering[multiscale_edges]
162+
if area_mask_builder is not None:
163+
inverse_ordering = np.full(coords_rad.shape[0], -1, dtype=int)
164+
inverse_ordering[node_ordering] = np.arange(len(node_ordering))
165+
updated_edges = inverse_ordering[multiscale_edges]
166+
valid_edges_mask = np.all(updated_edges >= 0, axis=0)
167+
multiscale_edges = updated_edges[:, valid_edges_mask]
168+
else:
169+
inverse_ordering = np.empty_like(node_ordering)
170+
inverse_ordering[node_ordering] = np.arange(len(node_ordering))
171+
multiscale_edges = inverse_ordering[multiscale_edges]
160172

161173
LOGGER.debug("multiscale_edges_shape: %s", multiscale_edges.shape)
162174
return multiscale_edges

0 commit comments

Comments
 (0)