@@ -2022,6 +2022,10 @@ def __init__(self, subject_id, hemi, surf, figure, geo, curv, title,
2022
2022
x , y , z , f = self ._geo .x , self ._geo .y , self ._geo .z , self ._geo .faces
2023
2023
self ._geo_mesh = mlab .pipeline .triangular_mesh_source (x , y , z , f ,
2024
2024
** meshargs )
2025
+ # add surface normals
2026
+ nn = _compute_normals (x , y , z , f )
2027
+ self ._geo_mesh .data .point_data .normals = nn
2028
+ self ._geo_mesh .data .cell_data .normals = None
2025
2029
self ._geo_surf = mlab .pipeline .surface (self ._geo_mesh ,
2026
2030
figure = self ._f , reset_zoom = True ,
2027
2031
** kwargs )
@@ -2418,6 +2422,75 @@ def _format_colorbar(self):
2418
2422
self .neg_bar .scalar_bar_representation .position2 = (0.42 , 0.09 )
2419
2423
2420
2424
2425
+ def _fast_cross_3d (x , y ):
2426
+ """Compute cross product between list of 3D vectors
2427
+
2428
+ Much faster than np.cross() when the number of cross products
2429
+ becomes large (>500). This is because np.cross() methods become
2430
+ less memory efficient at this stage.
2431
+
2432
+ Parameters
2433
+ ----------
2434
+ x : array
2435
+ Input array 1.
2436
+ y : array
2437
+ Input array 2.
2438
+
2439
+ Returns
2440
+ -------
2441
+ z : array
2442
+ Cross product of x and y.
2443
+
2444
+ Notes
2445
+ -----
2446
+ x and y must both be 2D row vectors. One must have length 1, or both
2447
+ lengths must match.
2448
+ """
2449
+ assert x .ndim == 2
2450
+ assert y .ndim == 2
2451
+ assert x .shape [1 ] == 3
2452
+ assert y .shape [1 ] == 3
2453
+ assert (x .shape [0 ] == 1 or y .shape [0 ] == 1 ) or x .shape [0 ] == y .shape [0 ]
2454
+ if max ([x .shape [0 ], y .shape [0 ]]) >= 500 :
2455
+ return np .c_ [x [:, 1 ] * y [:, 2 ] - x [:, 2 ] * y [:, 1 ],
2456
+ x [:, 2 ] * y [:, 0 ] - x [:, 0 ] * y [:, 2 ],
2457
+ x [:, 0 ] * y [:, 1 ] - x [:, 1 ] * y [:, 0 ]]
2458
+ else :
2459
+ return np .cross (x , y )
2460
+
2461
+
2462
+ def _compute_normals (x , y , z , tris ):
2463
+ """Efficiently compute vertex normals for triangulated surface"""
2464
+ # first, compute triangle normals
2465
+ rr = np .array ([x , y , z ]).T
2466
+ r1 = rr [tris [:, 0 ], :]
2467
+ r2 = rr [tris [:, 1 ], :]
2468
+ r3 = rr [tris [:, 2 ], :]
2469
+ tri_nn = _fast_cross_3d ((r2 - r1 ), (r3 - r1 ))
2470
+
2471
+ # Triangle normals and areas
2472
+ size = np .sqrt (np .sum (tri_nn * tri_nn , axis = 1 ))
2473
+ zidx = np .where (size == 0 )[0 ]
2474
+ size [zidx ] = 1.0 # prevent ugly divide-by-zero
2475
+ tri_nn /= size [:, np .newaxis ]
2476
+
2477
+ npts = len (rr )
2478
+
2479
+ # the following code replaces this, but is faster (vectorized):
2480
+ #
2481
+ # for p, verts in enumerate(tris):
2482
+ # nn[verts, :] += tri_nn[p, :]
2483
+ #
2484
+ nn = np .zeros ((npts , 3 ))
2485
+ for verts in tris .T : # note this only loops 3x (number of verts per tri)
2486
+ counts = np .bincount (verts , minlength = npts )
2487
+ reord = np .argsort (verts )
2488
+ vals = np .r_ [np .zeros ((1 , 3 )), np .cumsum (tri_nn [reord , :], 0 )]
2489
+ idx = np .cumsum (np .r_ [0 , counts ])
2490
+ nn += vals [idx [1 :], :] - vals [idx [:- 1 ], :]
2491
+ return nn
2492
+
2493
+
2421
2494
class TimeViewer (HasTraits ):
2422
2495
"""TimeViewer object providing a GUI for visualizing time series
2423
2496
0 commit comments