Skip to content

Commit c90e5db

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 5bc68d7 commit c90e5db

File tree

3 files changed

+18
-10
lines changed

3 files changed

+18
-10
lines changed

graphs/src/anemoi/graphs/edges/builders/multi_scale.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,15 @@ def compute_edge_index(self, source_nodes: NodeStorage, _target_nodes: NodeStora
106106
)
107107

108108
# Add edges
109-
source_nodes = edge_builder_cls().add_edges(source_nodes, self.x_hops, scale_resolutions=scale_resolutions, new_method=self.new_method)
109+
source_nodes = edge_builder_cls().add_edges(
110+
source_nodes, self.x_hops, scale_resolutions=scale_resolutions, new_method=self.new_method
111+
)
110112
if self.new_method:
111113
# If the new method is used, the edges are already computed and stored in the node storage
112114
edge_index = source_nodes["_multiscale_edges"]
113115
else:
114116
adjmat = nx.to_scipy_sparse_array(source_nodes["_nx_graph"], format="coo")
115117
# Get source & target indices of the edges
116118
edge_index = np.stack([adjmat.col, adjmat.row], axis=0)
117-
118119

119120
return torch.from_numpy(edge_index).to(torch.int32)

graphs/src/anemoi/graphs/generate/multi_scale_edges.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
# nor does it submit to any jurisdiction.
99

1010

11+
import logging
1112
from abc import ABC
1213
from abc import abstractmethod
1314

1415
from torch_geometric.data.storage import NodeStorage
15-
import logging
16+
1617
LOGGER = logging.getLogger(__name__)
1718

19+
1820
class BaseIcosahedronEdgeStrategy(ABC):
1921
"""Abstract base class for different edge-building strategies."""
2022

@@ -25,15 +27,17 @@ def add_edges(self, nodes: NodeStorage, x_hops: int, scale_resolutions: list[int
2527
class TriNodesEdgeBuilder(BaseIcosahedronEdgeStrategy):
2628
"""Edge builder for TriNodes and LimitedAreaTriNodes."""
2729

28-
def add_edges(self, nodes: NodeStorage, x_hops: int, scale_resolutions: list[int], new_method: bool = False ) -> NodeStorage:
30+
def add_edges(
31+
self, nodes: NodeStorage, x_hops: int, scale_resolutions: list[int], new_method: bool = False
32+
) -> NodeStorage:
2933
from anemoi.graphs.generate import tri_icosahedron
3034

3135
if new_method:
3236
assert x_hops == 1, "New strategy currently only supports x_hops=1."
3337
LOGGER.info("Using new strategy for x_hops=1 multiscale-edge building.")
3438
# Compute the multiscale edges directly and store them in the node storage
3539
multiscale_edges = tri_icosahedron.add_edges_hop_1(
36-
nodes_coords_rad = nodes["x"],
40+
nodes_coords_rad=nodes["x"],
3741
resolutions=scale_resolutions,
3842
node_ordering=nodes["_node_ordering"],
3943
)

graphs/src/anemoi/graphs/generate/tri_icosahedron.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,21 @@
88
# nor does it submit to any jurisdiction.
99

1010

11+
import logging
1112
from collections.abc import Iterable
1213

1314
import networkx as nx
1415
import numpy as np
1516
import trimesh
1617
from sklearn.neighbors import BallTree
17-
import logging
1818

1919
from anemoi.graphs.generate.masks import KNNAreaMaskBuilder
2020
from anemoi.graphs.generate.transforms import cartesian_to_latlon_rad
2121
from anemoi.graphs.generate.utils import get_coordinates_ordering
2222

2323
LOGGER = logging.getLogger(__name__)
2424

25+
2526
def create_tri_nodes(
2627
resolution: int, area_mask_builder: KNNAreaMaskBuilder | None = None
2728
) -> tuple[nx.DiGraph, np.ndarray, list[int]]:
@@ -135,21 +136,22 @@ def create_nx_graph_from_tri_coords(coords_rad: np.ndarray, node_ordering: np.nd
135136
assert graph.number_of_nodes() == len(node_ordering), "The number of nodes must be the same."
136137
return graph
137138

139+
138140
def add_edges_hop_1(nodes_coords_rad, resolutions: list[int], node_ordering: list[int]) -> np.ndarray:
139141
"""Adds edges for x_hops = 1 relying on trimesh only."""
140-
142+
141143
hop_1_edges = []
142144
for subdivisions in resolutions:
143145
sphere = trimesh.creation.icosphere(subdivisions=subdivisions, radius=1.0)
144146
coords_rad = cartesian_to_latlon_rad(sphere.vertices)
145147
LOGGER.debug("Adding %d 1-hop edges for resolution %d", sphere.edges.shape[0], subdivisions)
146148
hop_1_edges.append(sphere.edges)
147149
# Check that the node coordinates from the node storage match those were generated
148-
diff = np.sum(np.abs(nodes_coords_rad.numpy()[:coords_rad.shape[0],:] - coords_rad[node_ordering]))
150+
diff = np.sum(np.abs(nodes_coords_rad.numpy()[: coords_rad.shape[0], :] - coords_rad[node_ordering]))
149151
assert diff == 0, "Node coordinates do not match coordinates generated by trimesh."
150152

151153
# Concatenate all edges from different resolutions and transpose to get shape (2, num_edges)
152-
multiscale_edges = np.transpose(np.concatenate(hop_1_edges, axis=0), (1,0))
154+
multiscale_edges = np.transpose(np.concatenate(hop_1_edges, axis=0), (1, 0))
153155

154156
# Map the edges to the node ordering
155157
inverse_ordering = np.empty_like(node_ordering)
@@ -159,6 +161,7 @@ def add_edges_hop_1(nodes_coords_rad, resolutions: list[int], node_ordering: lis
159161
LOGGER.debug("multiscale_edges_shape: %s", multiscale_edges.shape)
160162
return multiscale_edges
161163

164+
162165
def add_edges_to_nx_graph(
163166
graph: nx.DiGraph,
164167
resolutions: list[int],
@@ -209,7 +212,7 @@ def add_edges_to_nx_graph(
209212
_, vertex_mapping_index = tree.query(r_vertices_rad, k=1)
210213
neighbour_pairs = create_node_neighbours_list(graph, node_neighbours, vertex_mapping_index)
211214
LOGGER.debug("Adding %d edges for resolution %d", len(neighbour_pairs), resolution)
212-
#LOGGER.debug("Sample edges: %s", neighbour_pairs)
215+
# LOGGER.debug("Sample edges: %s", neighbour_pairs)
213216
graph.add_edges_from(neighbour_pairs)
214217
return graph
215218

0 commit comments

Comments
 (0)