88# nor does it submit to any jurisdiction.
99
1010
11+ import logging
1112from collections .abc import Iterable
1213
1314import networkx as nx
1415import numpy as np
1516import trimesh
1617from sklearn .neighbors import BallTree
17- import logging
1818
1919from anemoi .graphs .generate .masks import KNNAreaMaskBuilder
2020from anemoi .graphs .generate .transforms import cartesian_to_latlon_rad
2121from anemoi .graphs .generate .utils import get_coordinates_ordering
2222
2323LOGGER = logging .getLogger (__name__ )
2424
25+
2526def create_tri_nodes (
2627 resolution : int , area_mask_builder : KNNAreaMaskBuilder | None = None
2728) -> tuple [nx .DiGraph , np .ndarray , list [int ]]:
@@ -135,21 +136,22 @@ def create_nx_graph_from_tri_coords(coords_rad: np.ndarray, node_ordering: np.nd
135136 assert graph .number_of_nodes () == len (node_ordering ), "The number of nodes must be the same."
136137 return graph
137138
139+
138140def add_edges_hop_1 (nodes_coords_rad , resolutions : list [int ], node_ordering : list [int ]) -> np .ndarray :
139141 """Adds edges for x_hops = 1 relying on trimesh only."""
140-
142+
141143 hop_1_edges = []
142144 for subdivisions in resolutions :
143145 sphere = trimesh .creation .icosphere (subdivisions = subdivisions , radius = 1.0 )
144146 coords_rad = cartesian_to_latlon_rad (sphere .vertices )
145147 LOGGER .debug ("Adding %d 1-hop edges for resolution %d" , sphere .edges .shape [0 ], subdivisions )
146148 hop_1_edges .append (sphere .edges )
147149 # Check that the node coordinates from the node storage match those were generated
148- diff = np .sum (np .abs (nodes_coords_rad .numpy ()[:coords_rad .shape [0 ],:] - coords_rad [node_ordering ]))
150+ diff = np .sum (np .abs (nodes_coords_rad .numpy ()[: coords_rad .shape [0 ], :] - coords_rad [node_ordering ]))
149151 assert diff == 0 , "Node coordinates do not match coordinates generated by trimesh."
150152
151153 # Concatenate all edges from different resolutions and transpose to get shape (2, num_edges)
152- multiscale_edges = np .transpose (np .concatenate (hop_1_edges , axis = 0 ), (1 ,0 ))
154+ multiscale_edges = np .transpose (np .concatenate (hop_1_edges , axis = 0 ), (1 , 0 ))
153155
154156 # Map the edges to the node ordering
155157 inverse_ordering = np .empty_like (node_ordering )
@@ -159,6 +161,7 @@ def add_edges_hop_1(nodes_coords_rad, resolutions: list[int], node_ordering: lis
159161 LOGGER .debug ("multiscale_edges_shape: %s" , multiscale_edges .shape )
160162 return multiscale_edges
161163
164+
162165def add_edges_to_nx_graph (
163166 graph : nx .DiGraph ,
164167 resolutions : list [int ],
@@ -209,7 +212,7 @@ def add_edges_to_nx_graph(
209212 _ , vertex_mapping_index = tree .query (r_vertices_rad , k = 1 )
210213 neighbour_pairs = create_node_neighbours_list (graph , node_neighbours , vertex_mapping_index )
211214 LOGGER .debug ("Adding %d edges for resolution %d" , len (neighbour_pairs ), resolution )
212- #LOGGER.debug("Sample edges: %s", neighbour_pairs)
215+ # LOGGER.debug("Sample edges: %s", neighbour_pairs)
213216 graph .add_edges_from (neighbour_pairs )
214217 return graph
215218
0 commit comments