11import numpy as np
2- from numba import njit
2+ from numba import njit , prange
33
44from uxarray .constants import INT_DTYPE , INT_FILL_VALUE
55
@@ -26,7 +26,6 @@ def construct_dual(grid):
2626 dual_node_z = grid .face_z .values
2727
2828 # Get other information from the grid needed
29- n_node = grid .n_node
3029 node_x = grid .node_x .values
3130 node_y = grid .node_y .values
3231 node_z = grid .node_z .values
@@ -35,10 +34,18 @@ def construct_dual(grid):
3534 # Get an array with the number of edges for each face
3635 n_edges_mask = node_face_connectivity != INT_FILL_VALUE
3736 n_edges = np .sum (n_edges_mask , axis = 1 )
37+ max_edges = node_face_connectivity .shape [1 ]
38+
39+ # Only nodes with 3+ edges can form valid dual faces
40+ valid_node_indices = np .where (n_edges >= 3 )[0 ]
41+
42+ construct_node_face_connectivity = np .full (
43+ (len (valid_node_indices ), max_edges ), INT_FILL_VALUE , dtype = INT_DTYPE
44+ )
3845
3946 # Construct and return the faces
4047 new_node_face_connectivity = construct_faces (
41- n_node ,
48+ valid_node_indices ,
4249 n_edges ,
4350 dual_node_x ,
4451 dual_node_y ,
@@ -47,14 +54,16 @@ def construct_dual(grid):
4754 node_x ,
4855 node_y ,
4956 node_z ,
57+ construct_node_face_connectivity ,
58+ max_edges ,
5059 )
5160
5261 return new_node_face_connectivity
5362
5463
55- @njit (cache = True )
64+ @njit (cache = True , parallel = True )
5665def construct_faces (
57- n_node ,
66+ valid_node_indices ,
5867 n_edges ,
5968 dual_node_x ,
6069 dual_node_y ,
@@ -63,61 +72,67 @@ def construct_faces(
6372 node_x ,
6473 node_y ,
6574 node_z ,
75+ construct_node_face_connectivity ,
76+ max_edges ,
6677):
6778 """Construct the faces of the dual mesh based on a given
6879 node_face_connectivity.
6980
7081 Parameters
7182 ----------
72- n_node : np.ndarray
73- number of nodes in the primal mesh
83+ valid_node_indices : np.ndarray
84+ Array of node indices with at least 3 connections in the primal mesh
7485 n_edges: np.ndarray
75- array of the number of edges for each dual face
86+ Array of the number of edges for each node in the primal mesh
7687 dual_node_x: np.ndarray
77- x node coordinates for the dual mesh
88+ x coordinates for the dual mesh nodes (face centers of primal mesh)
7889 dual_node_y: np.ndarray
79- y node coordinates for the dual mesh
90+ y coordinates for the dual mesh nodes (face centers of primal mesh)
8091 dual_node_z: np.ndarray
81- z node coordinates for the dual mesh
92+ z coordinates for the dual mesh nodes (face centers of primal mesh)
8293 node_face_connectivity: np.ndarray
83- `node_face_connectivity` of the primal mesh
94+ Node-to-face connectivity of the primal mesh
8495 node_x: np.ndarray
85- x node coordinates from the primal mesh
96+ x coordinates of nodes from the primal mesh
8697 node_y: np.ndarray
87- y node coordinates from the primal mesh
98+ y coordinates of nodes from the primal mesh
8899 node_z: np.ndarray
89- z node coordinates from the primal mesh
100+ z coordinates of nodes from the primal mesh
101+ construct_node_face_connectivity: np.ndarray
102+ Pre-allocated array to store the dual mesh connectivity
103+ max_edges: int
104+ The max number of edges in a face
90105
91106
92107 Returns
93108 --------
94- node_face_connectivity : ndarray
109+ construct_node_face_connectivity : ndarray
95110 Constructed node_face_connectivity for the dual mesh
111+
112+ Notes
113+ -----
114+ In dual mesh construction, the "valid node indices" are face indices from
115+ the primal mesh's node_face_connectivity that are not fill values. These
116+ represent the actual faces that each primal node connects to, which become
117+ the nodes of the dual mesh faces.
96118 """
97- correction = 0
98- max_edges = len (node_face_connectivity [0 ])
99- construct_node_face_connectivity = np .full (
100- (np .sum (n_edges > 2 ), max_edges ), INT_FILL_VALUE , dtype = INT_DTYPE
101- )
102- for i in range (n_node ):
103- # If we have less than 3 edges, we can't construct anything but a line
104- if n_edges [i ] < 3 :
105- correction += 1
106- continue
119+ n_valid = valid_node_indices .shape [0 ]
120+
121+ for out_idx in prange (n_valid ):
122+ i = valid_node_indices [out_idx ]
107123
108124 # Construct temporary face to hold unordered face nodes
109125 temp_face = np .array (
110126 [INT_FILL_VALUE for _ in range (n_edges [i ])], dtype = INT_DTYPE
111127 )
112128
113- # Get a list of the valid non fill value nodes
114- valid_node_indices = node_face_connectivity [i ][0 : n_edges [i ]]
115- index = 0
129+ # Get the face indices this node connects to (these become dual face nodes)
130+ connected_faces = node_face_connectivity [i ][0 : n_edges [i ]]
116131
117132 # Connect the face centers around the node to make dual face
118- for node_idx in valid_node_indices :
119- temp_face [ index ] = node_idx
120- index += 1
133+ for index , node_idx in enumerate ( connected_faces ) :
134+ if node_idx != INT_FILL_VALUE :
135+ temp_face [ index ] = node_idx
121136
122137 # Order the nodes using the angles so the faces have nodes in counter-clockwise sequence
123138 node_central = np .array ([node_x [i ], node_y [i ], node_z [i ]])
@@ -130,7 +145,7 @@ def construct_faces(
130145 )
131146
132147 # Order the face nodes properly in a counter-clockwise fashion
133- if temp_face [0 ] is not INT_FILL_VALUE :
148+ if temp_face [0 ] != INT_FILL_VALUE :
134149 _face = _order_nodes (
135150 temp_face ,
136151 node_0 ,
@@ -141,7 +156,8 @@ def construct_faces(
141156 dual_node_z ,
142157 max_edges ,
143158 )
144- construct_node_face_connectivity [i - correction ] = _face
159+ construct_node_face_connectivity [out_idx ] = _face
160+
145161 return construct_node_face_connectivity
146162
147163
@@ -183,10 +199,18 @@ def _order_nodes(
183199 final_face : np.ndarray
184200 The face in proper counter-clockwise order
185201 """
202+ # Add numerical stability check for degenerate cases
203+ if n_edges < 3 :
204+ return np .full (max_edges , INT_FILL_VALUE , dtype = INT_DTYPE )
205+
186206 node_zero = node_0 - node_central
207+ node_zero_mag = np .linalg .norm (node_zero )
208+
209+ # Check for numerical stability
210+ if node_zero_mag < 1e-15 :
211+ return np .full (max_edges , INT_FILL_VALUE , dtype = INT_DTYPE )
187212
188213 node_cross = np .cross (node_0 , node_central )
189- node_zero_mag = np .linalg .norm (node_zero )
190214
191215 d_angles = np .zeros (n_edges , dtype = np .float64 )
192216 d_angles [0 ] = 0.0
@@ -205,11 +229,16 @@ def _order_nodes(
205229 node_diff = sub - node_central
206230 node_diff_mag = np .linalg .norm (node_diff )
207231
232+ # Skip if node difference is too small (numerical stability)
233+ if node_diff_mag < 1e-15 :
234+ d_angles [j ] = 0.0
235+ continue
236+
208237 d_side = np .dot (node_cross , node_diff )
209238 d_dot_norm = np .dot (node_zero , node_diff ) / (node_zero_mag * node_diff_mag )
210239
211- if d_dot_norm > 1.0 :
212- d_dot_norm = 1.0
240+ # Clamp to valid range for arccos to avoid numerical errors
241+ d_dot_norm = max ( - 1.0 , min ( 1.0 , d_dot_norm ))
213242
214243 d_angles [j ] = np .arccos (d_dot_norm )
215244
0 commit comments