Skip to content

Commit 2e1b1db

Browse files
committed
ENH: Rasterize plots, show lateral and medial views
1 parent d92ed3f commit 2e1b1db

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

niworkflows/viz/plots.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,19 +1048,38 @@ def get_surface_meshes(density, surface_type):
10481048
if clip_range:
10491049
lh_data = np.clip(lh_data, clip_range[0], clip_range[1], out=lh_data)
10501050
rh_data = np.clip(rh_data, clip_range[0], clip_range[1], out=rh_data)
1051+
mn, mx = clip_range
1052+
else:
1053+
mn, mx = None, None
1054+
1055+
if mn is None:
1056+
mn = np.min(data)
1057+
if mx is None:
1058+
mx = np.max(data)
10511059

10521060
cmap = kwargs.pop('cmap', 'YlOrRd_r')
10531061

10541062
# Build the figure
10551063
lh_mesh, rh_mesh = get_surface_meshes(density, surface_type)
1056-
figure = plt.figure(figsize=plt.figaspect(0.5))
1057-
ax0 = figure.add_subplot(1, 2, 1, projection='3d')
1058-
plot_surf(lh_mesh, lh_data, cmap=cmap, axes=ax0, **kwargs)
1059-
ax1 = figure.add_subplot(1, 2, 2, projection='3d')
1060-
plot_surf(lh_mesh, lh_data, cmap=cmap, axes=ax1, **kwargs)
1064+
figure = plt.figure(figsize=plt.figaspect(0.25))
1065+
ax00 = figure.add_subplot(1, 4, 1, projection='3d', rasterized=True)
1066+
plot_surf(lh_mesh, lh_data, hemi='left', view='lateral', cmap=cmap, axes=ax00, figure=figure, **kwargs)
1067+
ax00.dist = 7
1068+
ax01 = figure.add_subplot(1, 4, 2, projection='3d', rasterized=True)
1069+
plot_surf(rh_mesh, rh_data, hemi='right', view='lateral', cmap=cmap, axes=ax01, figure=figure, **kwargs)
1070+
ax01.dist = 7
1071+
ax10 = figure.add_subplot(1, 4, 3, projection='3d', rasterized=True)
1072+
plot_surf(lh_mesh, lh_data, hemi='left', view='medial', cmap=cmap, axes=ax10, figure=figure, **kwargs)
1073+
ax10.dist = 7
1074+
ax11 = figure.add_subplot(1, 4, 4, projection='3d', rasterized=True)
1075+
plot_surf(rh_mesh, rh_data, hemi='right', view='medial', cmap=cmap, axes=ax11, figure=figure, **kwargs)
1076+
ax11.dist = 7
1077+
1078+
mappable = cm.ScalarMappable(norm=Normalize(mn, mx), cmap=cmap)
1079+
figure.colorbar(mappable, shrink=0.2, ax=figure.axes, location='bottom')
10611080

10621081
if output_file is not None:
1063-
figure.savefig(output_file, bbox_inches="tight")
1082+
figure.savefig(output_file, bbox_inches="tight", dpi=400)
10641083
plt.close(figure)
10651084
return output_file
10661085

0 commit comments

Comments
 (0)