Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/plotter_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

~CanvasWidget.active_artist
~CanvasWidget.active_selector
~CanvasWidget.show_color_overlay

.. rubric:: Methods Summary

Expand All @@ -21,19 +22,20 @@
~CanvasWidget.add_selector
~CanvasWidget.remove_selector
~CanvasWidget.on_enable_selector
~CanvasWidget.hide_color_overlay

.. rubric:: Signals Summary

.. autosummary::

~CanvasWidget.artist_changed_signal
~CanvasWidget.selector_changed_signal
~CanvasWidget.show_overlay_signal

.. rubric:: Properties Documentation

.. autoattribute:: active_artist
.. autoattribute:: active_selector
.. autoattribute:: show_color_overlay

.. rubric:: Methods Documentation

Expand All @@ -42,10 +44,10 @@
.. automethod:: add_selector
.. automethod:: remove_selector
.. automethod:: on_enable_selector
.. automethod:: hide_color_overlay

.. rubric:: Signals Documentation

.. autoattribute:: artist_changed_signal
.. autoattribute:: selector_changed_signal
.. autoattribute:: show_overlay_signal
```
2 changes: 1 addition & 1 deletion src/biaplotter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.2.0"
__version__ = "0.3.0"
from .artists import Histogram2D, Scatter
from .colormap import BiaColormap
from .plotter import CanvasWidget
Expand Down
20 changes: 11 additions & 9 deletions src/biaplotter/_tests/test_artists.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def on_color_indices_changed(color_indices):
assert scatter.color_indices.shape == (size,)

# Test scatter colors
colors = scatter._mpl_artists['scatter'].get_facecolors()
colors = scatter._mpl_artists["scatter"].get_facecolors()
assert np.all(colors[0] == scatter.overlay_colormap(0))
assert np.all(colors[50] == scatter.overlay_colormap(2))

Expand All @@ -69,27 +69,27 @@ def on_color_indices_changed(color_indices):
# Test size property
scatter.size = 5.0
assert scatter.size == 5.0
sizes = scatter._mpl_artists['scatter'].get_sizes()
sizes = scatter._mpl_artists["scatter"].get_sizes()
assert np.all(sizes == 5.0)

scatter.size = np.linspace(1, 10, size)
assert np.all(scatter.size == np.linspace(1, 10, size))
sizes = scatter._mpl_artists['scatter'].get_sizes()
sizes = scatter._mpl_artists["scatter"].get_sizes()
assert np.all(sizes == np.linspace(1, 10, size))

# Test size reset when new data is set
scatter.data = np.random.rand(size // 2, 2)
assert np.all(scatter.size == 50.0) # that's the default
sizes = scatter._mpl_artists['scatter'].get_sizes()
sizes = scatter._mpl_artists["scatter"].get_sizes()
assert np.all(sizes == 50.0)

# test alpha
scatter.alpha = 0.5
assert np.all(scatter._mpl_artists['scatter'].get_alpha() == 0.5)
assert np.all(scatter._mpl_artists["scatter"].get_alpha() == 0.5)

# test alpha reset when new data is set
scatter.data = np.random.rand(size, 2)
assert np.all(scatter._mpl_artists['scatter'].get_alpha() == 1.0)
assert np.all(scatter._mpl_artists["scatter"].get_alpha() == 1.0)

# Test changing overlay_colormap
assert scatter.overlay_colormap.name == "cat10_modified"
Expand All @@ -98,7 +98,7 @@ def on_color_indices_changed(color_indices):

# Test scatter color indices after continuous overlay_colormap
scatter.color_indices = np.linspace(0, 1, size)
colors = scatter._mpl_artists['scatter'].get_facecolors()
colors = scatter._mpl_artists["scatter"].get_facecolors()
assert np.all(colors[0] == plt.cm.viridis(0))

# Test scatter color_normalization_method
Expand Down Expand Up @@ -181,7 +181,9 @@ def on_color_indices_changed(color_indices):
assert histogram.cmin == 0

# Test overlay colors
overlay_array = histogram._mpl_artists['overlay_histogram_image'].get_array()
overlay_array = histogram._mpl_artists[
"overlay_histogram_image"
].get_array()
assert overlay_array.shape == (bins, bins, 4)
# indices where overlay_array is not zero
indices = np.where(overlay_array[..., -1] != 0)
Expand Down Expand Up @@ -229,7 +231,7 @@ def on_color_indices_changed(color_indices):

# Don't draw overlay histogram if color_indices are nan
histogram.color_indices = np.nan
assert 'overlay_histogram_image' not in histogram._mpl_artists.keys()
assert "overlay_histogram_image" not in histogram._mpl_artists.keys()


# Test calculate_statistic_histogram_method for different statistics
Expand Down
8 changes: 4 additions & 4 deletions src/biaplotter/_tests/test_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ def test_disable_all_selectors(canvas_widget):
assert selector._selector is None


def test_hide_color_overlay(canvas_widget):
"""Test the hide_color_overlay method."""
canvas_widget.hide_color_overlay(True)
def test_show_color_overlay(canvas_widget):
"""Test the show_color_overlay method."""
canvas_widget.show_color_overlay = False
assert not canvas_widget.active_artist.overlay_visible

canvas_widget.hide_color_overlay(False)
canvas_widget.show_color_overlay = True
assert canvas_widget.active_artist.overlay_visible


Expand Down
91 changes: 49 additions & 42 deletions src/biaplotter/artists.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
SymLogNorm)
from nap_plot_tools.cmap import (cat10_mod_cmap,
cat10_mod_cmap_first_transparent)
from scipy.stats import binned_statistic_2d

from biaplotter.colormap import BiaColormap
from scipy.stats import binned_statistic_2d

from .artists_base import Artist


Expand Down Expand Up @@ -52,8 +53,6 @@
>>> plt.show()
"""



def __init__(
self,
ax: plt.Axes = None,
Expand All @@ -74,16 +73,17 @@
def _refresh(self, force_redraw: bool = True):
"""Creates the scatter plot with the data and default properties."""

if force_redraw or self._mpl_artists['scatter'] is None:
if force_redraw or self._mpl_artists["scatter"] is None:
self._remove_artists()
# Create a new scatter plot with the updated data
self._mpl_artists['scatter'] = self.ax.scatter(
self._data[:, 0], self._data[:, 1])
self._mpl_artists["scatter"] = self.ax.scatter(
self._data[:, 0], self._data[:, 1]
)
self.size = 50 # Default size
self.alpha = 1 # Default alpha
self.color_indices = 0
else:
self._mpl_artists['scatter'].set_offsets(
self._mpl_artists["scatter"].set_offsets(
self._data
) # somehow resets the size and alpha
self.color_indices = self._color_indices
Expand All @@ -95,22 +95,21 @@
Add a color to the drawn scatter points
"""
rgba_colors = self.color_indices_to_rgba(indices)
self._mpl_artists['scatter'].set_facecolor(rgba_colors)
self._mpl_artists['scatter'].set_edgecolor("white")
self._mpl_artists["scatter"].set_facecolor(rgba_colors)
self._mpl_artists["scatter"].set_edgecolor("white")

return rgba_colors

def color_indices_to_rgba(
self,
indices: np.ndarray,
is_overlay: bool = True) -> np.ndarray:
self, indices: np.ndarray, is_overlay: bool = True
) -> np.ndarray:
"""
Convert color indices to RGBA colors using the colormap.
"""
norm = self._get_normalization(indices)
colormap = self.overlay_colormap.cmap

rgba = colormap(norm(self._color_indices))
rgba = colormap(norm(indices))
return rgba

def _get_normalization(self, values: np.ndarray) -> Normalize:
Expand All @@ -127,7 +126,8 @@
}

normalization_func = norm_dispatch.get(
self._color_normalization_method)
self._color_normalization_method
)
if normalization_func is None:
raise ValueError(
f"Unknown color normalization method: {self._color_normalization_method}.\n"
Expand Down Expand Up @@ -165,7 +165,11 @@
def overlay_visible(self, value: bool):
"""Sets the visibility of the overlay colormap."""
self._overlay_visible = value
self._colorize(self._color_indices)
if value:
self._colorize(self._color_indices)

Check warning on line 169 in src/biaplotter/artists.py

View check run for this annotation

Codecov / codecov/patch

src/biaplotter/artists.py#L168-L169

Added lines #L168 - L169 were not covered by tests
else:
self._colorize(np.zeros_like(self._color_indices))
self.draw()

Check warning on line 172 in src/biaplotter/artists.py

View check run for this annotation

Codecov / codecov/patch

src/biaplotter/artists.py#L171-L172

Added lines #L171 - L172 were not covered by tests

@property
def color_normalization_method(self) -> str:
Expand Down Expand Up @@ -193,7 +197,7 @@
alpha : float
alpha value of the scatter plot.
"""
return self._mpl_artists['scatter'].get_alpha()
return self._mpl_artists["scatter"].get_alpha()

Check warning on line 200 in src/biaplotter/artists.py

View check run for this annotation

Codecov / codecov/patch

src/biaplotter/artists.py#L200

Added line #L200 was not covered by tests

@alpha.setter
def alpha(self, value: Union[float, np.ndarray]):
Expand All @@ -202,8 +206,8 @@

if np.isscalar(value):
value = np.ones(len(self._data)) * value
if 'scatter' in self._mpl_artists.keys():
self._mpl_artists['scatter'].set_alpha(value)
if "scatter" in self._mpl_artists.keys():
self._mpl_artists["scatter"].set_alpha(value)
self.draw()

@property
Expand All @@ -223,8 +227,8 @@
def size(self, value: Union[float, np.ndarray]):
"""Sets the size of the points in the scatter plot."""
self._size = value
if 'scatter' in self._mpl_artists.keys():
self._mpl_artists['scatter'].set_sizes(
if "scatter" in self._mpl_artists.keys():
self._mpl_artists["scatter"].set_sizes(
np.full(len(self._data), value)
if np.isscalar(value)
else value
Expand Down Expand Up @@ -305,14 +309,14 @@
self._histogram_rgba = self.color_indices_to_rgba(
counts.T, is_overlay=False
)
self._mpl_artists['histogram_image'] = self.ax.imshow(
self._mpl_artists["histogram_image"] = self.ax.imshow(
self._histogram_rgba,
extent=[x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]],
origin="lower",
zorder=1,
interpolation=self._histogram_interpolation,
alpha=1,
aspect='auto'
aspect="auto",
)

if force_redraw:
Expand All @@ -328,33 +332,35 @@
_, x_edges, y_edges = self._histogram
# Assign median values to the bins (fill with NaNs if no data in the bin)
statistic_histogram, _, _, _ = binned_statistic_2d(
x = self._data[:, 0],
y= self._data[:, 1],
x=self._data[:, 0],
y=self._data[:, 1],
values=indices,
statistic=_median_np,
bins=[x_edges, y_edges]
bins=[x_edges, y_edges],
)
if not np.all(np.isnan(statistic_histogram)):
# Draw the overlay
self.overlay_histogram_rgba = self.color_indices_to_rgba(
statistic_histogram.T, is_overlay=True
)
self._mpl_artists['overlay_histogram_image'] = self.ax.imshow(
self._mpl_artists["overlay_histogram_image"] = self.ax.imshow(
self.overlay_histogram_rgba,
extent=[x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]],
origin="lower",
zorder=2,
interpolation=self._overlay_interpolation,
alpha=self._overlay_opacity,
aspect='auto'
aspect="auto",
)

def color_indices_to_rgba(self, indices, is_overlay: bool = True) -> np.ndarray:
def color_indices_to_rgba(
self, indices, is_overlay: bool = True
) -> np.ndarray:
"""
Convert color indices to RGBA colors using the overlay colormap.
"""
norm = self._get_normalization(indices, is_overlay=is_overlay)

if is_overlay:
colormap = self.overlay_colormap.cmap
else:
Expand Down Expand Up @@ -430,7 +436,7 @@
minimum count for the histogram.
"""
return self._cmin

@cmin.setter
def cmin(self, value: int):
"""Sets the minimum count for the histogram."""
Expand Down Expand Up @@ -524,8 +530,8 @@
def overlay_visible(self, value):
"""Sets the visibility of the overlay histogram."""
self._overlay_visible = value
if 'overlay_histogram_image' in self._mpl_artists:
self._mpl_artists['overlay_histogram_image'].set_visible(value)
if "overlay_histogram_image" in self._mpl_artists:
self._mpl_artists["overlay_histogram_image"].set_visible(value)
self.draw()

@property
Expand Down Expand Up @@ -602,9 +608,8 @@
return False

def _get_normalization(
self,
values: np.ndarray,
is_overlay: bool = True) -> Normalize:
self, values: np.ndarray, is_overlay: bool = True
) -> Normalize:
"""
Get the normalization class for the histogram data.

Expand Down Expand Up @@ -634,17 +639,19 @@
# norm_dispatch is to be indexed like this:
# norm_dispatch[is_categorical, color_normalization_method]
norm_dispatch = {
(True, 'linear'): lambda: self._linear_normalization(values, is_categorical),
(False, 'linear'): lambda: self._linear_normalization(values),
(False, 'log'): lambda: self._log_normalization(values),
(False, 'centered'): lambda: self._centered_normalization(values),
(False, 'symlog'): lambda: self._symlog_normalization(values),
(True, "linear"): lambda: self._linear_normalization(
values, is_categorical
),
(False, "linear"): lambda: self._linear_normalization(values),
(False, "log"): lambda: self._log_normalization(values),
(False, "centered"): lambda: self._centered_normalization(values),
(False, "symlog"): lambda: self._symlog_normalization(values),
}

return norm_dispatch.get((is_categorical, norm_method))()


def _median_np(arr, method='lower') -> float:
def _median_np(arr, method="lower") -> float:
"""Calculate the median of a 1D array.

Parameters
Expand All @@ -661,4 +668,4 @@
"""
if len(arr) == 0:
return np.nan
return np.nanpercentile(arr, 50, method=method)
return np.nanpercentile(arr, 50, method=method)
Loading