|
14 | 14 | import numpy as np |
15 | 15 | import trimesh |
16 | 16 | from sklearn.neighbors import BallTree |
| 17 | +import logging |
17 | 18 |
|
18 | 19 | from anemoi.graphs.generate.masks import KNNAreaMaskBuilder |
19 | 20 | from anemoi.graphs.generate.transforms import cartesian_to_latlon_rad |
20 | 21 | from anemoi.graphs.generate.utils import get_coordinates_ordering |
21 | 22 |
|
| 23 | +LOGGER = logging.getLogger(__name__) |
22 | 24 |
|
23 | 25 | def create_tri_nodes( |
24 | 26 | resolution: int, area_mask_builder: KNNAreaMaskBuilder | None = None |
@@ -133,6 +135,29 @@ def create_nx_graph_from_tri_coords(coords_rad: np.ndarray, node_ordering: np.nd |
133 | 135 | assert graph.number_of_nodes() == len(node_ordering), "The number of nodes must be the same." |
134 | 136 | return graph |
135 | 137 |
|
| 138 | +def add_edges_hop_1(nodes_coords_rad, resolutions: list[int], node_ordering: list[int]) -> np.ndarray: |
| 139 | + """Adds edges for x_hops = 1 relying on trimesh only.""" |
| 140 | + |
| 141 | + hop_1_edges = [] |
| 142 | + for subdivisions in resolutions: |
| 143 | + sphere = trimesh.creation.icosphere(subdivisions=subdivisions, radius=1.0) |
| 144 | + coords_rad = cartesian_to_latlon_rad(sphere.vertices) |
| 145 | + LOGGER.debug("Adding %d 1-hop edges for resolution %d", sphere.edges.shape[0], subdivisions) |
| 146 | + hop_1_edges.append(sphere.edges) |
| 147 | + # Check that the node coordinates from the node storage match those were generated |
| 148 | + diff = np.sum(np.abs(nodes_coords_rad.numpy()[:coords_rad.shape[0],:] - coords_rad[node_ordering])) |
| 149 | + assert diff == 0, "Node coordinates do not match coordinates generated by trimesh." |
| 150 | + |
| 151 | + # Concatenate all edges from different resolutions and transpose to get shape (2, num_edges) |
| 152 | + multiscale_edges = np.transpose(np.concatenate(hop_1_edges, axis=0), (1,0)) |
| 153 | + |
| 154 | + # Map the edges to the node ordering |
| 155 | + inverse_ordering = np.empty_like(node_ordering) |
| 156 | + inverse_ordering[node_ordering] = np.arange(len(node_ordering)) |
| 157 | + multiscale_edges = inverse_ordering[multiscale_edges] |
| 158 | + |
| 159 | + LOGGER.debug("multiscale_edges_shape: %s", multiscale_edges.shape) |
| 160 | + return multiscale_edges |
136 | 161 |
|
137 | 162 | def add_edges_to_nx_graph( |
138 | 163 | graph: nx.DiGraph, |
@@ -183,6 +208,8 @@ def add_edges_to_nx_graph( |
183 | 208 |
|
184 | 209 | _, vertex_mapping_index = tree.query(r_vertices_rad, k=1) |
185 | 210 | neighbour_pairs = create_node_neighbours_list(graph, node_neighbours, vertex_mapping_index) |
| 211 | + LOGGER.debug("Adding %d edges for resolution %d", len(neighbour_pairs), resolution) |
| 212 | + #LOGGER.debug("Sample edges: %s", neighbour_pairs) |
186 | 213 | graph.add_edges_from(neighbour_pairs) |
187 | 214 | return graph |
188 | 215 |
|
|
0 commit comments