2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -18,7 +18,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: "3.8"
python-version: "3.12"

- name: Build source and wheel distributions
run: |
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
@@ -21,7 +21,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: "3.8"
python-version: "3.12"

- name: Build source and wheel distributions
run: |
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -14,7 +14,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.11", "3.12"]

steps:
- uses: actions/checkout@v3
36 changes: 27 additions & 9 deletions bootplot/backend/base.py
@@ -14,23 +14,38 @@ def __init__(self,
f: callable,
data: Union[np.ndarray, pd.DataFrame],
m: int,
output_size_px: Tuple[int, int]):
output_size_px: Tuple[int, int],
single_sample: bool):
self.output_size_px = output_size_px
self.f = f
self.data = data
self.m = m
self.single_sample = single_sample

@abstractmethod
def create_figure(self):
raise NotImplementedError

def plot(self):
indices = np.random.randint(low=0, high=len(self.data), size=len(self.data))
size = 1 if self.single_sample else len(self.data)
indices = np.random.randint(low=0, high=len(self.data), size=size)
if isinstance(self.data, pd.DataFrame):
return self.f(self.data.iloc[indices], self.data, *self.plot_args)
elif isinstance(self.data, np.ndarray):
return self.f(self.data[indices], self.data, *self.plot_args)

def sample_all_indices(self):
size = 1 if self.single_sample else len(self.data)
return np.random.randint(0, len(self.data), size=(self.m, size))


def plot_from_indices(self, indices):
if isinstance(self.data, pd.DataFrame):
subset = self.data.iloc[indices]
else:
subset = self.data[indices]
return self.f(subset, self.data, *self.plot_args)

@abstractmethod
def plot_to_array(self) -> np.ndarray:
raise NotImplementedError
@@ -50,8 +65,8 @@ def plot_args(self):


class Basic(Backend):
def __init__(self, f: callable, data: Union[np.ndarray, pd.DataFrame], m: int, output_size_px: Tuple[int, int]):
super().__init__(f, data, m, output_size_px)
def __init__(self, f: callable, data: Union[np.ndarray, pd.DataFrame], m: int, output_size_px: Tuple[int, int], single_sample: bool):
super().__init__(f, data, m, output_size_px, single_sample)
self.cached_image = None

def plot(self):
@@ -79,10 +94,11 @@ def __init__(self,
f: callable,
data: Union[np.ndarray, pd.DataFrame],
m: int,
output_size_px: Tuple[int, int] = (512, 512)):
output_size_px: Tuple[int, int] = (512, 512),
single_sample: bool = False):
self.fig = None
self.ax = None
super().__init__(f, data, m, output_size_px)
super().__init__(f, data, m, output_size_px, single_sample)

def create_figure(self):
self.fig, self.ax = bootplot.backend.matplotlib.create_figure(self.output_size_px)
@@ -106,8 +122,9 @@ def __init__(self,
f: callable,
data: Union[np.ndarray, pd.DataFrame],
m: int,
output_size_px: Tuple[int, int] = (512, 512)):
super().__init__(f, data, m, output_size_px)
output_size_px: Tuple[int, int] = (512, 512),
single_sample: bool = False):
super().__init__(f, data, m, output_size_px, single_sample)

def create_figure(self):
raise NotImplementedError
@@ -131,7 +148,8 @@ def __init__(self,
f: callable,
data: Union[np.ndarray, pd.DataFrame],
m: int,
output_size_px: Tuple[int, int] = (512, 512)):
output_size_px: Tuple[int, int] = (512, 512),
single_sample: bool = False):
super().__init__(f, data, m, output_size_px)

def create_figure(self):
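Note (editor's illustration, not part of the diff): the new single_sample flag changes how the backend draws bootstrap indices. With single_sample=False it resamples len(data) indices with replacement, so each frame is a full bootstrap replicate; with single_sample=True it draws exactly one index, so each frame plots a single observation. A minimal standalone NumPy sketch of the index-sampling rule mirrored by Backend.sample_all_indices above:

import numpy as np

def sample_indices(data, m, single_sample=False):
    # one row of indices per bootstrap frame; each row holds either
    # len(data) indices drawn with replacement or a single index
    size = 1 if single_sample else len(data)
    return np.random.randint(0, len(data), size=(m, size))

data = np.arange(10)
print(sample_indices(data, m=3).shape)                       # (3, 10)
print(sample_indices(data, m=3, single_sample=True).shape)   # (3, 1)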
160 changes: 98 additions & 62 deletions bootplot/base.py
@@ -4,68 +4,88 @@
import numpy as np
import imageio
import pandas as pd
from matplotlib import pyplot as plt
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
from tqdm import tqdm
from PIL import Image, ImageFilter
from scipy.stats import beta

from bootplot.backend.base import Backend, create_backend
from bootplot.sorting import sort_images
from collections import Counter

import jax.numpy as jnp
from jax import jit, vmap, device_get
from jax.scipy.special import betainc

def plot(plot_function: callable,
data: Union[np.ndarray, pd.DataFrame],
indices: np.ndarray,
backend: Backend,
**kwargs):
if isinstance(data, pd.DataFrame):
plot_function(data.iloc[indices], data, *backend.plot_args, **kwargs)
else:
plot_function(data[indices], data, *backend.plot_args, **kwargs)

def symmetric_transformation_new(x,
k,
threshold):
y = beta.cdf(x, k, k)
return (1-2*threshold) * y + threshold

def adjust_relative_frequencies_opt(relative_frequencies,
k,
threshold):
dominant_color = max(relative_frequencies, key=relative_frequencies.get)
transformed_dominant = symmetric_transformation_new(relative_frequencies[dominant_color], k, threshold)
sum_other = 1-relative_frequencies[dominant_color]
transformed_other = 1-transformed_dominant
return {
color: transformed_other * rel_freq / sum_other if color != dominant_color else
transformed_dominant
for color, rel_freq in relative_frequencies.items()
}

def merge_images(images: np.ndarray,
k: int,
threshold: int) -> np.ndarray:
num_images, rows, cols, _ = images.shape
new_image = np.zeros((rows, cols, 3), dtype=np.uint8)

# Iterate over each pixel location
for i in range(rows):
for j in range(cols):
# Extract the colors at the current pixel location across all images
pixel_colors = [tuple(images[img, i, j]) for img in range(num_images)]
# Count the occurrence of each color in this list of colors
color_counts = Counter(pixel_colors)
percentages_old = {color: count / sum(color_counts.values()) for color, count in color_counts.items()}
if len(percentages_old) > 1:
percentages = adjust_relative_frequencies_opt(percentages_old, k, threshold)
new_color = np.sum([np.array(c) * p for c, p in percentages.items()], axis=0)
new_color = np.clip(new_color, 0, 255).astype(np.uint8)
new_image[i, j] = new_color
else:
new_image[i,j] = list(percentages_old.keys())[0]
return new_image

def symmetric_transformation_new(x: float,
k: float,
threshold: float) -> float:
y = betainc(k, k, x)
return (1 - 2 * threshold) * y + threshold

def adjust_freqs(freqs: jnp.ndarray,
k: float,
threshold: float) -> jnp.ndarray:
dom_idx = jnp.argmax(freqs)
dom = freqs[dom_idx]

t_dom = symmetric_transformation_new(dom, k, threshold)
sum_other = 1.0 - dom
scale = (1.0 - t_dom) / sum_other

out = freqs * scale
return out.at[dom_idx].set(t_dom)


def process_pixel(pixel_stack: jnp.ndarray,
k: float,
threshold: float) -> jnp.ndarray:
mn = pixel_stack.shape[0]

r = pixel_stack[:, 0].astype(jnp.int32)
g = pixel_stack[:, 1].astype(jnp.int32)
b = pixel_stack[:, 2].astype(jnp.int32)

idx = (r << 16) + (g << 8) + b

uniq, counts = jnp.unique(idx, size=mn, fill_value=0, return_counts=True)

n_unique = jnp.sum(counts > 0)

ur = ((uniq >> 16) & 255).astype(jnp.float32)
ug = ((uniq >> 8) & 255).astype(jnp.float32)
ub = (uniq & 255).astype(jnp.float32)

colors = jnp.stack([ur, ug, ub], axis=1)

freqs = counts.astype(jnp.float32) / mn

only_one = (n_unique == 1)
one_color = colors[0].astype(jnp.uint8)

freqs_adj = adjust_freqs(freqs, k, threshold)

rgb = jnp.sum(colors * freqs_adj[:, None], axis=0)
rgb = jnp.clip(rgb, 0, 255).astype(jnp.uint8)

return jnp.where(only_one, one_color, rgb)



@jit
def merge_images(images: np.ndarray,
k: float,
threshold: float) -> jnp.ndarray:
mn, rows, cols, _ = images.shape

pixels = images.transpose(1, 2, 0, 3)

# each of the rows * cols elements is a list of RGB pixels from all images at the same location:
pixels = pixels.reshape(rows * cols, mn, 3)

fused = vmap(process_pixel, in_axes=(0, None, None))(pixels, k, threshold)
return fused.reshape(rows, cols, 3)


def merge_images_original(images: np.ndarray) -> np.ndarray:
@@ -84,6 +104,7 @@ def merge_images_original(images: np.ndarray) -> np.ndarray:
return merged



def decay_images(images: np.ndarray,
m: int,
decay_length: int) -> np.ndarray:
@@ -107,12 +128,16 @@ def decay_images(images: np.ndarray,
return decayed_images





def bootplot(f: callable,
data: Union[np.ndarray, pd.DataFrame],
m: int = 100,
k: float = 2.5,
threshold: float = 0.3,
output_size_px: Tuple[int, int] = (512, 512),
single_sample: bool = False,
output_image_path: Union[str, Path] = None,
transformation: bool = True,
output_animation_path: Union[str, Path] = None,
@@ -147,9 +172,12 @@ def bootplot(f: callable,
:param threshold: input transformation parameter. Controls the codomain of the transformation. It lies between 0 and 0.5. Default: ``0.3``.
:type threshold: float

:param output_size_px: output size (height, width) in pixels. Default: ``(512, 512)``.
:param output_size_px: output size (width, height) in pixels. Default: ``(512, 512)``.
:type output_size_px: tuple[int, int]

:param single_sample: if ``True``, the bootstrapped data subset passed to ``f`` consists of a single sample. Default: ``False``.
:type single_sample: bool

:param output_image_path: path where the image should be stored. The image format is inferred from the filename
extension. If None, the image is not stored. Default: ``None``.
:type output_image_path: str or pathlib.Path
@@ -223,28 +251,36 @@ def bootplot(f: callable,
>>> image.shape
(512, 512, 3)
"""


if isinstance(backend, str):
backend = create_backend(backend, f, data, m, output_size_px=output_size_px)
backend_class = create_backend(backend, f, data, m, output_size_px=output_size_px, single_sample=single_sample)

backend.create_figure()
backend_class.create_figure()
images = []
for _ in tqdm(range(m), desc='Generating plots', disable=not verbose):
backend.plot()
image = backend.plot_to_array()
backend_class.plot()
image = backend_class.plot_to_array()
images.append(image)
backend.clear_figure()
backend.close_figure()
backend_class.clear_figure()
backend_class.close_figure()
images = np.stack(images)


if transformation:
merged_image = merge_images(images[..., :3], k, threshold)
merged_image = np.array(merge_images(images[..., :3], k, threshold))

else:
merged_image = merge_images_original(images[..., :3])

if output_image_path is not None:
if verbose:
print(f'> Saving bootstrapped image to {output_image_path}')
Image.fromarray(merged_image).save(output_image_path)
if isinstance(backend, str) and backend.lower() == "matplotlib":
dpi = plt.rcParams['figure.dpi']
Image.fromarray(merged_image).save(output_image_path, dpi=(dpi, dpi))
else:
Image.fromarray(merged_image).save(output_image_path)
if output_animation_path is not None:
sort_kwargs = dict() if sort_kwargs is None else sort_kwargs
order = sort_images(images, sort_type, verbose=verbose, **sort_kwargs)
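For readers of this diff, the new merging path works per pixel as follows: collect the colours that pixel takes across the m bootstrapped frames, compute their relative frequencies, push the dominant frequency x through the symmetric Beta-CDF transform t(x) = (1 - 2*threshold) * I_x(k, k) + threshold (I_x(k, k) is the regularized incomplete beta function, i.e. the CDF of a Beta(k, k) distribution, so t maps [0, 1] onto [threshold, 1 - threshold]), rescale the remaining frequencies so the weights still sum to 1, and blend the colours with those weights; the JAX code vectorises this over all rows * cols pixel stacks with vmap. A small NumPy/SciPy reference sketch of the per-pixel rule (editor's illustration; the inputs are assumed to be the distinct colours at one pixel and their frequencies, not the library's actual call signature):

import numpy as np
from scipy.stats import beta

def blend_pixel(colors, freqs, k=2.5, threshold=0.3):
    # colors: (n, 3) array of distinct RGB colours seen at one pixel
    # freqs:  (n,) relative frequencies summing to 1
    colors = np.asarray(colors, dtype=float)
    freqs = np.asarray(freqs, dtype=float)
    if len(freqs) == 1:                                  # single colour: keep it
        return colors[0].astype(np.uint8)
    dom = np.argmax(freqs)
    t_dom = (1 - 2 * threshold) * beta.cdf(freqs[dom], k, k) + threshold
    out = freqs * (1 - t_dom) / (1 - freqs[dom])         # rescale non-dominant colours
    out[dom] = t_dom
    rgb = (colors * out[:, None]).sum(axis=0)
    return np.clip(rgb, 0, 255).astype(np.uint8)

# a pixel that is white in 90% of frames and red in 10%:
print(blend_pixel([[255, 255, 255], [255, 0, 0]], [0.9, 0.1]))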
13 changes: 7 additions & 6 deletions requirements.txt
@@ -1,12 +1,13 @@
numpy>=1.3,<2
imageio~=2.9.0
imageio-ffmpeg==0.4.7
numpy>=1.3
imageio>=2.9.0
imageio-ffmpeg>=0.4.7
matplotlib>=3
tqdm~=4.64.0
tqdm>=4.64.0
pillow>=8
scipy>=1.5
scikit-image>=0.17
networkx>=2.7.1
scikit-learn>=0.24
opencv-python~=4.5.5
pandas~=1.4.3
opencv-python>=4.5.5
pandas>=1.4.3
jax>=0.8.1
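For context on how the pieces above fit together, a hedged end-to-end usage sketch of the updated API. The import path follows the file path in this diff (bootplot/base.py); the callback signature f(data_subset, data_full, ax) is an assumption based on how the matplotlib backend forwards *self.plot_args, so adjust it if the actual signature differs.

import numpy as np
from bootplot.base import bootplot

def plot_mean(data_subset, data_full, ax):
    # one bootstrap estimate of the mean, drawn as a vertical line
    ax.axvline(np.mean(data_subset))
    ax.set_xlim(float(data_full.min()), float(data_full.max()))

data = np.random.default_rng(0).normal(size=200)
image = bootplot(plot_mean, data, m=100, k=2.5, threshold=0.3,
                 transformation=True, single_sample=False,
                 output_image_path='bootstrapped_mean.png',
                 backend='matplotlib')
print(image.shape)  # (512, 512, 3), per the docstring example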