diff --git a/.gitattributes b/.gitattributes index f73ff28..802a4b3 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,4 +1,4 @@ -phys2bids/_version.py export-subst +nigsp/_version.py export-subst *.py eol=lf *.rst eol=lf diff --git a/nigsp/blocks.py b/nigsp/blocks.py index 5db5df9..6c67dd1 100644 --- a/nigsp/blocks.py +++ b/nigsp/blocks.py @@ -99,7 +99,7 @@ def export_metric(scgraph, outext, outprefix): return 0 -def plot_metric(scgraph, outprefix, atlas=None, thr=None): +def plot_metric(scgraph, outprefix, atlas=None, title=None, thr=None): """ If possible, plot metrics as markerplot. @@ -109,10 +109,10 @@ def plot_metric(scgraph, outprefix, atlas=None, thr=None): The internal object containing all data. outprefix : str The prefix of the png file to export - img : 3DNiftiImage or None, optional - The nifti image of the atlas atlas : 3D Nifti1Image, numpy.ndarray, or None, optional Either a nifti image containing a valid atlas or a set of parcel coordinates. + title : None or str, optional + Add a title to the graph thr : float or None, optional The threshold to use in plotting the nodes. """ @@ -135,7 +135,11 @@ def plot_metric(scgraph, outprefix, atlas=None, thr=None): if atlas_plot is not None: if scgraph.sdi is not None: viz.plot_nodes( - scgraph.sdi, atlas_plot, filename=f"{outprefix}sdi.png", thr=thr + scgraph.sdi, + atlas_plot, + filename=f"{outprefix}sdi.png", + title=title, + thr=thr, ) elif scgraph.gsdi is not None: for k in scgraph.gsdi.keys(): @@ -143,6 +147,7 @@ def plot_metric(scgraph, outprefix, atlas=None, thr=None): scgraph.gsdi[k], atlas_plot, filename=f"{outprefix}gsdi_{k}.png", + title=title, thr=thr, ) diff --git a/nigsp/objects.py b/nigsp/objects.py index 489fb10..da9716c 100644 --- a/nigsp/objects.py +++ b/nigsp/objects.py @@ -130,7 +130,7 @@ def compute_graph_energy(self, mean=False): # pragma: no cover ) return self - def split_graph(self, index=None, keys=["low", "high"]): + def split_graph(self, index=None, keys=["low-pass", "high-pass"]): """Implement timeseries.median_cutoff_frequency_idx as class method.""" if index is None: index = self.index diff --git a/nigsp/operations/timeseries.py b/nigsp/operations/timeseries.py index 9351222..d005513 100644 --- a/nigsp/operations/timeseries.py +++ b/nigsp/operations/timeseries.py @@ -205,7 +205,7 @@ def resize_ts(timeseries, resize=None, globally=False): if resize == "spc": # pragma: no cover LGR.info("Expressing timeseries in signal percentage change") timeseries = spc_ts(timeseries, globally=globally) - elif resize == "norm": # pragma: no cover + elif resize in ["norm", "zscore"]: # pragma: no cover LGR.info("Normalise timeseries") timeseries = normalise_ts(timeseries, globally=globally) elif resize == "demean": # pragma: no cover @@ -338,13 +338,13 @@ def median_cutoff_frequency_idx(energy): return freq_idx -def graph_filter(timeseries, eigenvec, freq_idx, keys=["low", "high"]): +def graph_filter(timeseries, eigenvec, freq_idx, keys=["low-pass", "high-pass"]): """ Filter a graph decomposition into two parts based on freq_idx. Return the two eigenvector lists (high freq and low freq) that are equal - to the original eigenvector list, but "low" is zero-ed for all frequencies - >= of the given index, and "high" is zero-ed for all frequencies < to the + to the original eigenvector list, but "low-pass" is zero-ed for all frequencies + >= of the given index, and "high-pass" is zero-ed for all frequencies < to the given index. Also return their projection onto a timeseries. @@ -357,7 +357,7 @@ def graph_filter(timeseries, eigenvec, freq_idx, keys=["low", "high"]): freq_idx : int or list The index of the frequency that splits the spectral power into two (more or less) equal parts - i.e. the index of the first frequency in - the "high" component. + the "high-pass" component. keys : list, optional The keys to call the split parts with @@ -371,8 +371,8 @@ def graph_filter(timeseries, eigenvec, freq_idx, keys=["low", "high"]): Raises ------ IndexError - If the given index is 0 (all "high"), the last possible index (all "low"), - or higher than the last possible index (not applicable). + If the given index is 0 (all "high-pass"), the last possible index + (all "low-pass"), or higher than the last possible index (not applicable). """ # #!# Find better name # #!# Implement an index splitter diff --git a/nigsp/tests/test_integration.py b/nigsp/tests/test_integration.py index 8584a4c..e409266 100644 --- a/nigsp/tests/test_integration.py +++ b/nigsp/tests/test_integration.py @@ -34,27 +34,27 @@ def test_integration(timeseries, sc_mtx, atlas, mean_fc, sdi, testdir): # Check that files were created assert isdir(testdir) assert isdir(join(testdir, "logs")) - assert isdir(join(testdir, "testfile_timeseries_low")) - assert isdir(join(testdir, "testfile_timeseries_high")) - assert isfile(join(testdir, "testfile_timeseries_low", "000.tsv")) - assert isfile(join(testdir, "testfile_timeseries_high", "000.tsv")) + assert isdir(join(testdir, "testfile_timeseries_low-pass")) + assert isdir(join(testdir, "testfile_timeseries_high-pass")) + assert isfile(join(testdir, "testfile_timeseries_low-pass", "000.tsv")) + assert isfile(join(testdir, "testfile_timeseries_high-pass", "000.tsv")) assert isfile(join(testdir, "testfile_fc.tsv")) - assert isfile(join(testdir, "testfile_fc_low.tsv")) - assert isfile(join(testdir, "testfile_fc_high.tsv")) + assert isfile(join(testdir, "testfile_fc_low-pass.tsv")) + assert isfile(join(testdir, "testfile_fc_high-pass.tsv")) assert isfile(join(testdir, "testfile_eigenval.tsv")) assert isfile(join(testdir, "testfile_eigenvec.tsv")) - assert isfile(join(testdir, "testfile_eigenvec_low.tsv")) - assert isfile(join(testdir, "testfile_eigenvec_high.tsv")) + assert isfile(join(testdir, "testfile_eigenvec_low-pass.tsv")) + assert isfile(join(testdir, "testfile_eigenvec_high-pass.tsv")) assert isfile(join(testdir, "testfile_sdi.tsv")) assert isfile(join(testdir, "testfile_mkd_sdi.tsv")) assert isfile(join(testdir, "testfile_laplacian.png")) assert isfile(join(testdir, "testfile_sc.png")) assert isfile(join(testdir, "testfile_fc.png")) - assert isfile(join(testdir, "testfile_fc_low.png")) - assert isfile(join(testdir, "testfile_fc_high.png")) + assert isfile(join(testdir, "testfile_fc_low-pass.png")) + assert isfile(join(testdir, "testfile_fc_high-pass.png")) assert isfile(join(testdir, "testfile_greyplot.png")) - assert isfile(join(testdir, "testfile_greyplot_low.png")) - assert isfile(join(testdir, "testfile_greyplot_high.png")) + assert isfile(join(testdir, "testfile_greyplot_low-pass.png")) + assert isfile(join(testdir, "testfile_greyplot_high-pass.png")) assert isfile(join(testdir, "testfile_sdi.png")) assert isfile(join(testdir, "testfile_mkd_sdi.png")) diff --git a/nigsp/tests/test_metrics.py b/nigsp/tests/test_metrics.py index 55d4b30..edf641e 100644 --- a/nigsp/tests/test_metrics.py +++ b/nigsp/tests/test_metrics.py @@ -16,11 +16,11 @@ def test_sdi(): ts3 = np.arange(5, 7)[..., np.newaxis] sdi_in = np.log2(np.arange(3.0, 1.0, -1.0)) - ts = {"low": ts1, "high": ts2} + ts = {"low-pass": ts1, "high-pass": ts2} sdi_out = metrics.sdi(ts) assert (sdi_out == sdi_in).all() - ts = {"HIGH": ts2, "LOW": ts1} + ts = {"HIGH-PASS": ts2, "LOW-PASS": ts1} sdi_out = metrics.sdi(ts) assert (sdi_out == sdi_in).all() @@ -29,8 +29,8 @@ def test_sdi(): assert (sdi_out == sdi_in).all() ts = { - "low": np.repeat(np.repeat(ts1[..., np.newaxis], 3, axis=1), 3, axis=2), - "high": np.repeat(np.repeat(ts2[..., np.newaxis], 3, axis=1), 3, axis=2), + "low-pass": np.repeat(np.repeat(ts1[..., np.newaxis], 3, axis=1), 3, axis=2), + "high-pass": np.repeat(np.repeat(ts2[..., np.newaxis], 3, axis=1), 3, axis=2), } sdi_out = metrics.sdi(ts, mean=True) sdi_out = np.around(sdi_out, decimals=15) @@ -80,7 +80,7 @@ def test_break_sdi(): ts = {"alpha": ts1, "beta": ts2, "gamma": ts3} with raises(ValueError) as errorinfo: - metrics.sdi(ts, keys=["high", "low"]) + metrics.sdi(ts, keys=["high-pass", "low-pass"]) assert "provided keys" in str(errorinfo.value) with raises(ValueError) as errorinfo: diff --git a/nigsp/viz.py b/nigsp/viz.py index 8b0bd1e..3b042a6 100644 --- a/nigsp/viz.py +++ b/nigsp/viz.py @@ -26,12 +26,12 @@ LGR = logging.getLogger(__name__) SET_DPI = 100 -FIGSIZE = (18, 10) +FIGSIZE = (12, 7) FIGSIZE_SQUARE = (6, 5) -FIGSIZE_LONG = (10, 5) +FIGSIZE_LONG = (12, 4) -def plot_connectivity(mtx, filename=None, closeplot=False): +def plot_connectivity(mtx, filename=None, title=None, crange=None, closeplot=False): """ Create a connectivity matrix plot. @@ -43,6 +43,10 @@ def plot_connectivity(mtx, filename=None, closeplot=False): A (square) array with connectivity information inside. filename : None, str, or os.PathLike, optional The path to save the plot on disk. + title : None or str, optional + Add a title to the graph + range : None or list, optional + Set vmin and vmax closeplot : bool, optional Whether to close plots after saving or not. Mainly used for debug or use with live python/ipython instances. @@ -85,11 +89,31 @@ def plot_connectivity(mtx, filename=None, closeplot=False): LGR.warning("Given matrix is not a square matrix!") LGR.info("Creating connectivity plot.") - plt.figure(figsize=FIGSIZE_SQUARE) - plot_matrix(mtx) + fig = plt.figure(figsize=FIGSIZE_SQUARE) + ax = fig.subplots() + + pc_args = {"mat": mtx, "axes": ax} + if crange is not None: + if type(crange) in [list, tuple]: + pc_args["vmin"] = crange[0] + pc_args["vmax"] = crange[1] + else: + vmax = np.nanpercentile(mtx, 98) # mtx.max() + vmin = np.abs(np.nanpercentile(mtx, 2)) # mtx.min() + if crange == "auto-symm" and mtx.min() < 0 and vmax > 0: + pc_args["vmax"] = vmax if vmax > vmin else vmin + pc_args["vmin"] = -vmin if vmin > vmax else -vmax + elif crange == "auto-zero" or mtx.min() > 0 or vmax < 0: + pass + else: + raise NotImplementedError(f"{crange} option not implemented.") + + plot_matrix(**pc_args) + if title is not None: + fig.suptitle(title) if filename is not None: - plt.savefig(filename, dpi=SET_DPI) + plt.savefig(filename, dpi=SET_DPI, bbox_inches="tight") closeplot = True if closeplot: @@ -162,17 +186,18 @@ def plot_greyplot(timeseries, filename=None, title=None, resize=None, closeplot= timeseries = resize_ts(timeseries, resize) LGR.info("Creating greyplot.") - plt.figure(figsize=FIGSIZE_LONG) - if title is not None: - plt.title(title) vmax = np.percentile(timeseries, 99) vmin = np.percentile(timeseries, 1) - plt.imshow(timeseries, cmap="gray", vmin=vmin, vmax=vmax) - plt.colorbar() + fig = plt.figure(figsize=FIGSIZE_LONG) + ax = fig.subplots() + im = ax.imshow(timeseries, cmap="gray", vmin=vmin, vmax=vmax) + plt.colorbar(im, ax=ax) + if title is not None: + fig.suptitle(title) plt.tight_layout() if filename is not None: - plt.savefig(filename, dpi=SET_DPI) + plt.savefig(filename, dpi=SET_DPI, bbox_inches="tight") closeplot = True if closeplot: @@ -183,7 +208,9 @@ def plot_greyplot(timeseries, filename=None, title=None, resize=None, closeplot= return 0 -def plot_nodes(ns, atlas, filename=None, thr=None, closeplot=False): +def plot_nodes( + ns, atlas, filename=None, title=None, thr=None, cmap=None, closeplot=False +): """ Create a marker plot in the MNI space. @@ -198,8 +225,12 @@ def plot_nodes(ns, atlas, filename=None, thr=None, closeplot=False): or a list of coordinates of the center of mass of parcels. filename : None, str, or os.PathLike, optional The path to save the plot on disk. + title : None or str, optional + Add a title to the graph thr : float or None, optional The threshold to use in plotting the nodes. + cmap : None or matplotlib.pyplot.cm colormap object, optional. + The colormap to adopt in plotting nodes. Defaults to reverse viridis. closeplot : bool, optional Whether to close plots after saving or not. Mainly used for debug or use with live python/ipython instances. @@ -252,11 +283,17 @@ def plot_nodes(ns, atlas, filename=None, thr=None, closeplot=False): raise ValueError("Node array and coordinates array have different length.") LGR.info("Creating markerplot.") - plt.figure(figsize=FIGSIZE) - plot_markers(ns, coord, node_threshold=thr) + fig = plt.figure(figsize=FIGSIZE) + ax = fig.subplots() + cmap = plt.cm.viridis_r if cmap is None else cmap + plot_markers(ns, coord, axes=ax, node_threshold=thr, node_cmap=cmap) + if title is not None: + fig.suptitle(title) + + plt.tight_layout() if filename is not None: - plt.savefig(filename, dpi=SET_DPI) + plt.savefig(filename, dpi=SET_DPI, bbox_inches="tight") closeplot = True if closeplot: @@ -265,7 +302,9 @@ def plot_nodes(ns, atlas, filename=None, thr=None, closeplot=False): return 0 -def plot_edges(mtx, atlas, filename=None, thr=None, closeplot=False): +def plot_edges( + mtx, atlas, filename=None, title=None, thr=None, cmap=None, closeplot=False +): """ Create a connectivity plot in the MNI space. @@ -280,9 +319,13 @@ def plot_edges(mtx, atlas, filename=None, thr=None, closeplot=False): or a list of coordinates of the center of mass of parcels. filename : None, str, or os.PathLike, optional The path to save the plot on disk. + title : None or str, optional + Add a title to the graph thr : float, str or None, optional The threshold to use in plotting the nodes. If `str`, needs to express a percentage. + cmap : None or matplotlib.pyplot.cm colormap object, optional. + The colormap to adopt in plotting nodes. Defaults to reverse viridis. closeplot : bool, optional Whether to close plots after saving or not. Mainly used for debug or use with live python/ipython instances. @@ -335,7 +378,8 @@ def plot_edges(mtx, atlas, filename=None, thr=None, closeplot=False): raise ValueError("Matrix axis and coordinates array have different length.") LGR.info("Creating connectome-like plot.") - plt.figure(figsize=FIGSIZE) + fig = plt.figure(figsize=FIGSIZE) + ax = fig.subplots() pc_args = { "adjacency_matrix": mtx, @@ -343,7 +387,9 @@ def plot_edges(mtx, atlas, filename=None, thr=None, closeplot=False): "node_color": "black", "node_size": 5, "edge_threshold": thr, + "edge_cmap": plt.cm.bwr, "colorbar": True, + "axes": ax, } if mtx.min() >= 0: @@ -351,10 +397,15 @@ def plot_edges(mtx, atlas, filename=None, thr=None, closeplot=False): pc_args["edge_vmax"] = mtx.max() pc_args["edge_cmap"] = cm.red_transparent_full_alpha_range + pc_args["edge_cmap"] = pc_args["edge_cmap"] if cmap is None else cmap plot_connectome(**pc_args) + if title is not None: + fig.suptitle(title) + + plt.tight_layout() if filename is not None: - plt.savefig(filename, dpi=SET_DPI) + plt.savefig(filename, dpi=SET_DPI, bbox_inches="tight") closeplot = True if closeplot: diff --git a/nigsp/workflow.py b/nigsp/workflow.py index 43afad6..806d4ab 100644 --- a/nigsp/workflow.py +++ b/nigsp/workflow.py @@ -402,31 +402,55 @@ def nigsp( # Plot original SC and Laplacian LGR.info("Plot laplacian matrix.") - viz.plot_connectivity(scgraph.lapl_mtx, f"{outprefix}laplacian.png") + viz.plot_connectivity( + scgraph.lapl_mtx, f"{outprefix}laplacian.png", title="Laplacian matrix" + ) LGR.info("Plot structural connectivity matrix.") - viz.plot_connectivity(scgraph.mtx, f"{outprefix}sc.png") + viz.plot_connectivity( + scgraph.mtx, + f"{outprefix}sc.png", + title="Structural Connectivity (underlying graph)", + ) # Plot timeseries LGR.info("Plot original timeseries.") - viz.plot_greyplot(scgraph.timeseries, f"{outprefix}greyplot.png") + viz.plot_greyplot( + scgraph.timeseries, f"{outprefix}greyplot.png", title="Original timeseries" + ) for k in scgraph.split_keys: LGR.info(f"Plot {k} timeseries.") - viz.plot_greyplot(scgraph.ts_split[k], f"{outprefix}greyplot_{k}.png") + viz.plot_greyplot( + scgraph.ts_split[k], + f"{outprefix}greyplot_{k}.png", + title=f"Filtered timeseries ({k} filter)", + ) if "dfc" in comp_metric or "fc" in comp_metric: # Plot FC LGR.info("Plot original functional connectivity matrix.") - viz.plot_connectivity(scgraph.fc, f"{outprefix}fc.png") + viz.plot_connectivity( + scgraph.fc, + f"{outprefix}fc.png", + title="Original Functional Connectivity", + ) for k in scgraph.split_keys: LGR.info(f"Plot {k} functional connectivity matrix.") - viz.plot_connectivity(scgraph.fc_split[k], f"{outprefix}fc_{k}.png") + viz.plot_connectivity( + scgraph.fc_split[k], + f"{outprefix}fc_{k}.png", + title=f"Filtered FC ({k} filter)", + ) if "sdi" in comp_metric or "gsdi" in comp_metric: if atlasname is not None: LGR.info(f"Plot {metric_name} markerplot.") if img is not None: - blocks.plot_metric(scgraph, outprefix, img) + blocks.plot_metric( + scgraph, outprefix, atlas=img, title=f"{metric_name.upper()}" + ) elif atlas is not None: - blocks.plot_metric(scgraph, outprefix, atlas) + blocks.plot_metric( + scgraph, outprefix, atlas=atlas, title=f"{metric_name.upper()}" + ) except ImportError: LGR.warning( @@ -455,9 +479,21 @@ def nigsp( if atlasname is not None: LGR.info(f"Plot {metric_name} markerplot.") if img is not None: - blocks.plot_metric(scgraph, outprefix, atlas=img, thr=0) + blocks.plot_metric( + scgraph, + outprefix, + atlas=img, + title=f"Statistically significant {metric_name.upper()} (Structurally {surr_type} surrogates)", + thr=0, + ) elif atlas is not None: - blocks.plot_metric(scgraph, outprefix, atlas=atlas, thr=0) + blocks.plot_metric( + scgraph, + outprefix, + atlas=atlas, + title=f"Statistically significant {metric_name.upper()} (Structurally {surr_type} surrogates)", + thr=0, + ) except ImportError: pass