Skip to content

Commit 5e1e954

Browse files
committed
ENH: Add vertex normals for smooth surface
1 parent 3bb7827 commit 5e1e954

File tree

1 file changed

+73
-0
lines changed

1 file changed

+73
-0
lines changed

surfer/viz.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2022,6 +2022,10 @@ def __init__(self, subject_id, hemi, surf, figure, geo, curv, title,
20222022
x, y, z, f = self._geo.x, self._geo.y, self._geo.z, self._geo.faces
20232023
self._geo_mesh = mlab.pipeline.triangular_mesh_source(x, y, z, f,
20242024
**meshargs)
2025+
# add surface normals
2026+
nn = _compute_normals(x, y, z, f)
2027+
self._geo_mesh.data.point_data.normals = nn
2028+
self._geo_mesh.data.cell_data.normals = None
20252029
self._geo_surf = mlab.pipeline.surface(self._geo_mesh,
20262030
figure=self._f, reset_zoom=True,
20272031
**kwargs)
@@ -2418,6 +2422,75 @@ def _format_colorbar(self):
24182422
self.neg_bar.scalar_bar_representation.position2 = (0.42, 0.09)
24192423

24202424

2425+
def _fast_cross_3d(x, y):
2426+
"""Compute cross product between list of 3D vectors
2427+
2428+
Much faster than np.cross() when the number of cross products
2429+
becomes large (>500). This is because np.cross() methods become
2430+
less memory efficient at this stage.
2431+
2432+
Parameters
2433+
----------
2434+
x : array
2435+
Input array 1.
2436+
y : array
2437+
Input array 2.
2438+
2439+
Returns
2440+
-------
2441+
z : array
2442+
Cross product of x and y.
2443+
2444+
Notes
2445+
-----
2446+
x and y must both be 2D row vectors. One must have length 1, or both
2447+
lengths must match.
2448+
"""
2449+
assert x.ndim == 2
2450+
assert y.ndim == 2
2451+
assert x.shape[1] == 3
2452+
assert y.shape[1] == 3
2453+
assert (x.shape[0] == 1 or y.shape[0] == 1) or x.shape[0] == y.shape[0]
2454+
if max([x.shape[0], y.shape[0]]) >= 500:
2455+
return np.c_[x[:, 1] * y[:, 2] - x[:, 2] * y[:, 1],
2456+
x[:, 2] * y[:, 0] - x[:, 0] * y[:, 2],
2457+
x[:, 0] * y[:, 1] - x[:, 1] * y[:, 0]]
2458+
else:
2459+
return np.cross(x, y)
2460+
2461+
2462+
def _compute_normals(x, y, z, tris):
2463+
"""Efficiently compute vertex normals for triangulated surface"""
2464+
# first, compute triangle normals
2465+
rr = np.array([x, y, z]).T
2466+
r1 = rr[tris[:, 0], :]
2467+
r2 = rr[tris[:, 1], :]
2468+
r3 = rr[tris[:, 2], :]
2469+
tri_nn = _fast_cross_3d((r2 - r1), (r3 - r1))
2470+
2471+
# Triangle normals and areas
2472+
size = np.sqrt(np.sum(tri_nn * tri_nn, axis=1))
2473+
zidx = np.where(size == 0)[0]
2474+
size[zidx] = 1.0 # prevent ugly divide-by-zero
2475+
tri_nn /= size[:, np.newaxis]
2476+
2477+
npts = len(rr)
2478+
2479+
# the following code replaces this, but is faster (vectorized):
2480+
#
2481+
# for p, verts in enumerate(tris):
2482+
# nn[verts, :] += tri_nn[p, :]
2483+
#
2484+
nn = np.zeros((npts, 3))
2485+
for verts in tris.T: # note this only loops 3x (number of verts per tri)
2486+
counts = np.bincount(verts, minlength=npts)
2487+
reord = np.argsort(verts)
2488+
vals = np.r_[np.zeros((1, 3)), np.cumsum(tri_nn[reord, :], 0)]
2489+
idx = np.cumsum(np.r_[0, counts])
2490+
nn += vals[idx[1:], :] - vals[idx[:-1], :]
2491+
return nn
2492+
2493+
24212494
class TimeViewer(HasTraits):
24222495
"""TimeViewer object providing a GUI for visualizing time series
24232496

0 commit comments

Comments
 (0)