Skip to content

Commit c40e239

Browse files
authored
Merge pull request #1274 from UXARRAY/zedwick/parallelize_dual
Parallelize Dual Grid Construction
2 parents e88b1da + 09af68f commit c40e239

File tree

2 files changed

+81
-48
lines changed

2 files changed

+81
-48
lines changed

uxarray/grid/dual.py

Lines changed: 66 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from numba import njit
2+
from numba import njit, prange
33

44
from uxarray.constants import INT_DTYPE, INT_FILL_VALUE
55

@@ -26,7 +26,6 @@ def construct_dual(grid):
2626
dual_node_z = grid.face_z.values
2727

2828
# Get other information from the grid needed
29-
n_node = grid.n_node
3029
node_x = grid.node_x.values
3130
node_y = grid.node_y.values
3231
node_z = grid.node_z.values
@@ -35,10 +34,18 @@ def construct_dual(grid):
3534
# Get an array with the number of edges for each dual face (i.e. the number of faces connected to each primal node)
3635
n_edges_mask = node_face_connectivity != INT_FILL_VALUE
3736
n_edges = np.sum(n_edges_mask, axis=1)
37+
max_edges = node_face_connectivity.shape[1]
38+
39+
# Only nodes with 3+ edges can form valid dual faces
40+
valid_node_indices = np.where(n_edges >= 3)[0]
41+
42+
construct_node_face_connectivity = np.full(
43+
(len(valid_node_indices), max_edges), INT_FILL_VALUE, dtype=INT_DTYPE
44+
)
3845

3946
# Construct and return the faces
4047
new_node_face_connectivity = construct_faces(
41-
n_node,
48+
valid_node_indices,
4249
n_edges,
4350
dual_node_x,
4451
dual_node_y,
@@ -47,14 +54,16 @@ def construct_dual(grid):
4754
node_x,
4855
node_y,
4956
node_z,
57+
construct_node_face_connectivity,
58+
max_edges,
5059
)
5160

5261
return new_node_face_connectivity
5362

5463

55-
@njit(cache=True)
64+
@njit(cache=True, parallel=True)
5665
def construct_faces(
57-
n_node,
66+
valid_node_indices,
5867
n_edges,
5968
dual_node_x,
6069
dual_node_y,
@@ -63,61 +72,67 @@ def construct_faces(
6372
node_x,
6473
node_y,
6574
node_z,
75+
construct_node_face_connectivity,
76+
max_edges,
6677
):
6778
"""Construct the faces of the dual mesh based on a given
6879
node_face_connectivity.
6980
7081
Parameters
7182
----------
72-
n_node: np.ndarray
73-
number of nodes in the primal mesh
83+
valid_node_indices: np.ndarray
84+
Array of node indices with at least 3 connections in the primal mesh
7485
n_edges: np.ndarray
75-
array of the number of edges for each dual face
86+
Array of the number of edges for each node in the primal mesh
7687
dual_node_x: np.ndarray
77-
x node coordinates for the dual mesh
88+
x coordinates for the dual mesh nodes (face centers of primal mesh)
7889
dual_node_y: np.ndarray
79-
y node coordinates for the dual mesh
90+
y coordinates for the dual mesh nodes (face centers of primal mesh)
8091
dual_node_z: np.ndarray
81-
z node coordinates for the dual mesh
92+
z coordinates for the dual mesh nodes (face centers of primal mesh)
8293
node_face_connectivity: np.ndarray
83-
`node_face_connectivity` of the primal mesh
94+
Node-to-face connectivity of the primal mesh
8495
node_x: np.ndarray
85-
x node coordinates from the primal mesh
96+
x coordinates of nodes from the primal mesh
8697
node_y: np.ndarray
87-
y node coordinates from the primal mesh
98+
y coordinates of nodes from the primal mesh
8899
node_z: np.ndarray
89-
z node coordinates from the primal mesh
100+
z coordinates of nodes from the primal mesh
101+
construct_node_face_connectivity: np.ndarray
102+
Pre-allocated array to store the dual mesh connectivity
103+
max_edges: int
104+
Number of columns in node_face_connectivity, i.e. the max number of edges in a dual face
90105
91106
92107
Returns
93108
--------
94-
node_face_connectivity : ndarray
109+
construct_node_face_connectivity : ndarray
95110
Constructed node_face_connectivity for the dual mesh
111+
112+
Notes
113+
-----
114+
In dual mesh construction, the non-fill entries in a primal node's row of
115+
node_face_connectivity are the indices of the faces that node connects to.
116+
Those faces become the nodes of the dual mesh face built around that node;
117+
`valid_node_indices` itself holds primal node indices, not face indices.
96118
"""
97-
correction = 0
98-
max_edges = len(node_face_connectivity[0])
99-
construct_node_face_connectivity = np.full(
100-
(np.sum(n_edges > 2), max_edges), INT_FILL_VALUE, dtype=INT_DTYPE
101-
)
102-
for i in range(n_node):
103-
# If we have less than 3 edges, we can't construct anything but a line
104-
if n_edges[i] < 3:
105-
correction += 1
106-
continue
119+
n_valid = valid_node_indices.shape[0]
120+
121+
for out_idx in prange(n_valid):
122+
i = valid_node_indices[out_idx]
107123

108124
# Construct temporary face to hold unordered face nodes
109125
temp_face = np.array(
110126
[INT_FILL_VALUE for _ in range(n_edges[i])], dtype=INT_DTYPE
111127
)
112128

113-
# Get a list of the valid non fill value nodes
114-
valid_node_indices = node_face_connectivity[i][0 : n_edges[i]]
115-
index = 0
129+
# Get the face indices this node connects to (these become dual face nodes)
130+
connected_faces = node_face_connectivity[i][0 : n_edges[i]]
116131

117132
# Connect the face centers around the node to make dual face
118-
for node_idx in valid_node_indices:
119-
temp_face[index] = node_idx
120-
index += 1
133+
for index, node_idx in enumerate(connected_faces):
134+
if node_idx != INT_FILL_VALUE:
135+
temp_face[index] = node_idx
121136

122137
# Order the nodes using the angles so the faces have nodes in counter-clockwise sequence
123138
node_central = np.array([node_x[i], node_y[i], node_z[i]])
@@ -130,7 +145,7 @@ def construct_faces(
130145
)
131146

132147
# Order the face nodes properly in a counter-clockwise fashion
133-
if temp_face[0] is not INT_FILL_VALUE:
148+
if temp_face[0] != INT_FILL_VALUE:
134149
_face = _order_nodes(
135150
temp_face,
136151
node_0,
@@ -141,7 +156,8 @@ def construct_faces(
141156
dual_node_z,
142157
max_edges,
143158
)
144-
construct_node_face_connectivity[i - correction] = _face
159+
construct_node_face_connectivity[out_idx] = _face
160+
145161
return construct_node_face_connectivity
146162

147163

@@ -183,10 +199,18 @@ def _order_nodes(
183199
final_face : np.ndarray
184200
The face in proper counter-clockwise order
185201
"""
202+
# Degenerate case: fewer than 3 nodes cannot form a face, so return an all-fill row
203+
if n_edges < 3:
204+
return np.full(max_edges, INT_FILL_VALUE, dtype=INT_DTYPE)
205+
186206
node_zero = node_0 - node_central
207+
node_zero_mag = np.linalg.norm(node_zero)
208+
209+
# Check for numerical stability
210+
if node_zero_mag < 1e-15:
211+
return np.full(max_edges, INT_FILL_VALUE, dtype=INT_DTYPE)
187212

188213
node_cross = np.cross(node_0, node_central)
189-
node_zero_mag = np.linalg.norm(node_zero)
190214

191215
d_angles = np.zeros(n_edges, dtype=np.float64)
192216
d_angles[0] = 0.0
@@ -205,11 +229,16 @@ def _order_nodes(
205229
node_diff = sub - node_central
206230
node_diff_mag = np.linalg.norm(node_diff)
207231

232+
# Skip if node difference is too small (numerical stability)
233+
if node_diff_mag < 1e-15:
234+
d_angles[j] = 0.0
235+
continue
236+
208237
d_side = np.dot(node_cross, node_diff)
209238
d_dot_norm = np.dot(node_zero, node_diff) / (node_zero_mag * node_diff_mag)
210239

211-
if d_dot_norm > 1.0:
212-
d_dot_norm = 1.0
240+
# Clamp to valid range for arccos to avoid numerical errors
241+
d_dot_norm = max(-1.0, min(1.0, d_dot_norm))
213242

214243
d_angles[j] = np.arccos(d_dot_norm)
215244

uxarray/remap/bilinear.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66
import xarray as xr
7-
from numba import njit
7+
from numba import njit, prange
88

99
if TYPE_CHECKING:
1010
from uxarray.core.dataarray import UxDataArray
@@ -140,14 +140,15 @@ def _barycentric_weights(point_xyz, dual, data_size, source_grid):
140140
dual.node_z.values,
141141
dual.face_node_connectivity.values,
142142
dual.n_nodes_per_face.values,
143+
dual.n_face,
143144
all_weights,
144145
all_indices,
145146
)
146147

147148
return all_weights, all_indices
148149

149150

150-
@njit(cache=True)
151+
@njit(cache=True, parallel=True)
151152
def _calculate_weights(
152153
valid_idxs,
153154
point_xyz,
@@ -157,13 +158,16 @@ def _calculate_weights(
157158
z,
158159
face_node_conn,
159160
n_nodes_per_face,
161+
n_faces,
160162
all_weights,
161163
all_indices,
162164
):
163-
for idx in valid_idxs:
164-
fidx = int(face_indices[idx, 0])
165+
for idx in prange(len(valid_idxs)):
166+
fidx = int(face_indices[valid_idxs[idx], 0])
167+
# bounds check: ensure face index is within valid range (0 to n_faces-1)
168+
if fidx < 0 or fidx >= n_faces:
169+
continue
165170
nverts = int(n_nodes_per_face[fidx])
166-
167171
polygon_xyz = np.zeros((nverts, 3), dtype=np.float64)
168172
polygon_face_indices = np.empty(nverts, dtype=np.int32)
169173
for j in range(nverts):
@@ -174,18 +178,18 @@ def _calculate_weights(
174178
polygon_face_indices[j] = node
175179

176180
# snap check
177-
match = _find_matching_node_index(polygon_xyz, point_xyz[idx])
181+
match = _find_matching_node_index(polygon_xyz, point_xyz[valid_idxs[idx]])
178182
if match[0] != -1:
179-
all_weights[idx, 0] = 1.0
180-
all_indices[idx, 0] = polygon_face_indices[match[0]]
183+
all_weights[valid_idxs[idx], 0] = 1.0
184+
all_indices[valid_idxs[idx], 0] = polygon_face_indices[match[0]]
181185
continue
182186

183187
weights, node_idxs = barycentric_coordinates_cartesian(
184-
polygon_xyz, point_xyz[idx]
188+
polygon_xyz, point_xyz[valid_idxs[idx]]
185189
)
186190
for k in range(len(weights)):
187-
all_weights[idx, k] = weights[k]
188-
all_indices[idx, k] = polygon_face_indices[node_idxs[k]]
191+
all_weights[valid_idxs[idx], k] = weights[k]
192+
all_indices[valid_idxs[idx], k] = polygon_face_indices[node_idxs[k]]
189193

190194

191195
@njit(cache=True)

0 commit comments

Comments
 (0)