Skip to content

Commit 45fd932

Browse files
committed
FIX: Move to utils
1 parent 5e1e954 commit 45fd932

File tree

2 files changed

+73
-71
lines changed

2 files changed

+73
-71
lines changed

surfer/utils.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ class Surface(object):
3939
The vertices coordinates
4040
faces : 2d array
4141
The faces ie. the triangles
42+
nn : 2d array
43+
Normalized surface normals for vertices.
4244
subjects_dir : str | None
4345
If not None, this directory will be used as the subjects directory
4446
instead of the value set using the SUBJECTS_DIR environment variable.
@@ -80,6 +82,7 @@ def load_geometry(self):
8082
self.coords[:, 0] -= (np.max(self.coords[:, 0]) + self.offset)
8183
else:
8284
self.coords[:, 0] -= (np.min(self.coords[:, 0]) + self.offset)
85+
self.nn = _compute_normals(self.coords, self.faces)
8386

8487
def save_geometry(self):
8588
surf_path = op.join(self.data_path, "surf",
@@ -128,6 +131,75 @@ def apply_xfm(self, mtx):
128131
mtx.T)[:, :3]
129132

130133

134+
def _fast_cross_3d(x, y):
135+
"""Compute cross product between list of 3D vectors
136+
137+
Much faster than np.cross() when the number of cross products
138+
becomes large (>500). This is because np.cross() methods become
139+
less memory efficient at this stage.
140+
141+
Parameters
142+
----------
143+
x : array
144+
Input array 1.
145+
y : array
146+
Input array 2.
147+
148+
Returns
149+
-------
150+
z : array
151+
Cross product of x and y.
152+
153+
Notes
154+
-----
155+
x and y must both be 2D row vectors. One must have length 1, or both
156+
lengths must match.
157+
"""
158+
assert x.ndim == 2
159+
assert y.ndim == 2
160+
assert x.shape[1] == 3
161+
assert y.shape[1] == 3
162+
assert (x.shape[0] == 1 or y.shape[0] == 1) or x.shape[0] == y.shape[0]
163+
if max([x.shape[0], y.shape[0]]) >= 500:
164+
return np.c_[x[:, 1] * y[:, 2] - x[:, 2] * y[:, 1],
165+
x[:, 2] * y[:, 0] - x[:, 0] * y[:, 2],
166+
x[:, 0] * y[:, 1] - x[:, 1] * y[:, 0]]
167+
else:
168+
return np.cross(x, y)
169+
170+
171+
def _compute_normals(rr, tris):
172+
"""Efficiently compute vertex normals for triangulated surface"""
173+
# first, compute triangle normals
174+
t0 = time.time()
175+
r1 = rr[tris[:, 0], :]
176+
r2 = rr[tris[:, 1], :]
177+
r3 = rr[tris[:, 2], :]
178+
tri_nn = _fast_cross_3d((r2 - r1), (r3 - r1))
179+
180+
# Triangle normals and areas
181+
size = np.sqrt(np.sum(tri_nn * tri_nn, axis=1))
182+
zidx = np.where(size == 0)[0]
183+
size[zidx] = 1.0 # prevent ugly divide-by-zero
184+
tri_nn /= size[:, np.newaxis]
185+
186+
npts = len(rr)
187+
188+
# the following code replaces this, but is faster (vectorized):
189+
#
190+
# for p, verts in enumerate(tris):
191+
# nn[verts, :] += tri_nn[p, :]
192+
#
193+
nn = np.zeros((npts, 3))
194+
for verts in tris.T: # note this only loops 3x (number of verts per tri)
195+
counts = np.bincount(verts, minlength=npts)
196+
reord = np.argsort(verts)
197+
vals = np.r_[np.zeros((1, 3)), np.cumsum(tri_nn[reord, :], 0)]
198+
idx = np.cumsum(np.r_[0, counts])
199+
nn += vals[idx[1:], :] - vals[idx[:-1], :]
200+
return nn
201+
202+
131203
###############################################################################
132204
# LOGGING (courtesy of mne-python)
133205

surfer/viz.py

Lines changed: 1 addition & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2023,8 +2023,7 @@ def __init__(self, subject_id, hemi, surf, figure, geo, curv, title,
20232023
self._geo_mesh = mlab.pipeline.triangular_mesh_source(x, y, z, f,
20242024
**meshargs)
20252025
# add surface normals
2026-
nn = _compute_normals(x, y, z, f)
2027-
self._geo_mesh.data.point_data.normals = nn
2026+
self._geo_mesh.data.point_data.normals = self._geo.nn
20282027
self._geo_mesh.data.cell_data.normals = None
20292028
self._geo_surf = mlab.pipeline.surface(self._geo_mesh,
20302029
figure=self._f, reset_zoom=True,
@@ -2422,75 +2421,6 @@ def _format_colorbar(self):
24222421
self.neg_bar.scalar_bar_representation.position2 = (0.42, 0.09)
24232422

24242423

2425-
def _fast_cross_3d(x, y):
2426-
"""Compute cross product between list of 3D vectors
2427-
2428-
Much faster than np.cross() when the number of cross products
2429-
becomes large (>500). This is because np.cross() methods become
2430-
less memory efficient at this stage.
2431-
2432-
Parameters
2433-
----------
2434-
x : array
2435-
Input array 1.
2436-
y : array
2437-
Input array 2.
2438-
2439-
Returns
2440-
-------
2441-
z : array
2442-
Cross product of x and y.
2443-
2444-
Notes
2445-
-----
2446-
x and y must both be 2D row vectors. One must have length 1, or both
2447-
lengths must match.
2448-
"""
2449-
assert x.ndim == 2
2450-
assert y.ndim == 2
2451-
assert x.shape[1] == 3
2452-
assert y.shape[1] == 3
2453-
assert (x.shape[0] == 1 or y.shape[0] == 1) or x.shape[0] == y.shape[0]
2454-
if max([x.shape[0], y.shape[0]]) >= 500:
2455-
return np.c_[x[:, 1] * y[:, 2] - x[:, 2] * y[:, 1],
2456-
x[:, 2] * y[:, 0] - x[:, 0] * y[:, 2],
2457-
x[:, 0] * y[:, 1] - x[:, 1] * y[:, 0]]
2458-
else:
2459-
return np.cross(x, y)
2460-
2461-
2462-
def _compute_normals(x, y, z, tris):
2463-
"""Efficiently compute vertex normals for triangulated surface"""
2464-
# first, compute triangle normals
2465-
rr = np.array([x, y, z]).T
2466-
r1 = rr[tris[:, 0], :]
2467-
r2 = rr[tris[:, 1], :]
2468-
r3 = rr[tris[:, 2], :]
2469-
tri_nn = _fast_cross_3d((r2 - r1), (r3 - r1))
2470-
2471-
# Triangle normals and areas
2472-
size = np.sqrt(np.sum(tri_nn * tri_nn, axis=1))
2473-
zidx = np.where(size == 0)[0]
2474-
size[zidx] = 1.0 # prevent ugly divide-by-zero
2475-
tri_nn /= size[:, np.newaxis]
2476-
2477-
npts = len(rr)
2478-
2479-
# the following code replaces this, but is faster (vectorized):
2480-
#
2481-
# for p, verts in enumerate(tris):
2482-
# nn[verts, :] += tri_nn[p, :]
2483-
#
2484-
nn = np.zeros((npts, 3))
2485-
for verts in tris.T: # note this only loops 3x (number of verts per tri)
2486-
counts = np.bincount(verts, minlength=npts)
2487-
reord = np.argsort(verts)
2488-
vals = np.r_[np.zeros((1, 3)), np.cumsum(tri_nn[reord, :], 0)]
2489-
idx = np.cumsum(np.r_[0, counts])
2490-
nn += vals[idx[1:], :] - vals[idx[:-1], :]
2491-
return nn
2492-
2493-
24942424
class TimeViewer(HasTraits):
24952425
"""TimeViewer object providing a GUI for visualizing time series
24962426

0 commit comments

Comments
 (0)