Skip to content

Commit ed66bd8

Browse files
Merge pull request #8 from davidnabergoj/updates
2 parents 80227a2 + 7353d0c commit ed66bd8

File tree

3 files changed

+44
-8
lines changed

3 files changed

+44
-8
lines changed

bootplot/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.0.12"
1+
__version__ = "0.0.13"

bootplot/base.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,17 @@ def plot(plot_function: callable,
2525
else:
2626
plot_function(data[indices], data, *backend.plot_args, **kwargs)
2727

28-
def symmetric_transformation_new(x, k=0.5, threshold = 0.3):
28+
def symmetric_transformation_new(x,
29+
k,
30+
threshold):
2931
y = beta.cdf(x, k, k)
3032
return (1-2*threshold) * y + threshold
3133

32-
def adjust_relative_frequencies_opt(relative_frequencies):
34+
def adjust_relative_frequencies_opt(relative_frequencies,
35+
k,
36+
threshold):
3337
dominant_color = max(relative_frequencies, key=relative_frequencies.get)
34-
transformed_dominant = symmetric_transformation_new(relative_frequencies[dominant_color])
38+
transformed_dominant = symmetric_transformation_new(relative_frequencies[dominant_color], k, threshold)
3539
sum_other = 1-relative_frequencies[dominant_color]
3640
transformed_other = 1-transformed_dominant
3741
return {
@@ -40,7 +44,9 @@ def adjust_relative_frequencies_opt(relative_frequencies):
4044
for color, rel_freq in relative_frequencies.items()
4145
}
4246

43-
def merge_images(images: np.ndarray) -> np.ndarray:
47+
def merge_images(images: np.ndarray,
48+
k: int,
49+
threshold: int) -> np.ndarray:
4450
num_images, rows, cols, _ = images.shape
4551
new_image = np.zeros((rows, cols, 3), dtype=np.uint8)
4652

@@ -53,7 +59,7 @@ def merge_images(images: np.ndarray) -> np.ndarray:
5359
color_counts = Counter(pixel_colors)
5460
percentages_old = {color: count / sum(color_counts.values()) for color, count in color_counts.items()}
5561
if len(percentages_old) > 1:
56-
percentages = adjust_relative_frequencies_opt(percentages_old)
62+
percentages = adjust_relative_frequencies_opt(percentages_old, k, threshold)
5763
new_color = np.sum([np.array(c) * p for c, p in percentages.items()], axis=0)
5864
new_color = np.clip(new_color, 0, 255).astype(np.uint8)
5965
new_image[i, j] = new_color
@@ -62,6 +68,21 @@ def merge_images(images: np.ndarray) -> np.ndarray:
6268
return new_image
6369

6470

71+
def merge_images_original(images: np.ndarray) -> np.ndarray:
72+
"""
73+
Merge images into a static image (averaged image) without transformation.
74+
The shape of images is (batch_size, width, height, channels).
75+
This operation overwrites input images.
76+
77+
:param images: images corresponding to different bootstrap resamples.
78+
:param images: images corresponding to different bootstrap samples.
79+
:return: merged image.
80+
"""
81+
images = images.astype(np.float32) / 255 # Cast to float
82+
merged = np.mean(images, axis=0)
83+
merged = (merged * 255).astype(np.uint8)
84+
return merged
85+
6586

6687
def decay_images(images: np.ndarray,
6788
m: int,
@@ -89,8 +110,11 @@ def decay_images(images: np.ndarray,
89110
def bootplot(f: callable,
90111
data: Union[np.ndarray, pd.DataFrame],
91112
m: int = 100,
113+
k: int = 2.5,
114+
threshold: int = 0.3,
92115
output_size_px: Tuple[int, int] = (512, 512),
93116
output_image_path: Union[str, Path] = None,
117+
transformation: bool = True,
94118
output_animation_path: Union[str, Path] = None,
95119
sort_type: str = 'tsp',
96120
sort_kwargs: dict = None,
@@ -117,13 +141,22 @@ def bootplot(f: callable,
117141
:param m: number of boostrap resamples. Default: ``100``.
118142
:type m: int
119143
144+
:param k: input beta cdf transformation parameter. Controls the shape Default: ``2.5``.
145+
:type k: int
146+
147+
:param threshold: input transformation parameter. Controls the codomain of the transformation. It lies between 0 and 1. Default: ``0,3``.
148+
:type threshold: int
149+
120150
:param output_size_px: output size (height, width) in pixels. Default: ``(512, 512)``.
121151
:type output_size_px: tuple[int, int]
122152
123153
:param output_image_path: path where the image should be stored. The image format is inferred from the filename
124154
extension. If None, the image is not stored. Default: ``None``.
125155
:type output_image_path: str or pathlib.Path
126156
157+
:param transformation: if True transformation is applied, else images are just averaged. Default: ``True``.
158+
:type transformation: bool
159+
127160
:param output_animation_path: path where the animation should be stored. The animation format is inferred from the
128161
filename extension. If None, the animation is not created. Default: ``None``.
129162
:type output_animation_path: str or pathlib.Path
@@ -203,7 +236,10 @@ def bootplot(f: callable,
203236
backend.close_figure()
204237
images = np.stack(images)
205238

206-
merged_image = merge_images(images[..., :3])
239+
if transformation:
240+
merged_image = merge_images(images[..., :3], k, threshold)
241+
else:
242+
merged_image = merge_images_original(images[..., :3])
207243

208244
if output_image_path is not None:
209245
if verbose:

test/test_bootplot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def test_bmp():
7979
def test_merge():
8080
np.random.seed(0)
8181
images = np.random.randint(low=0, high=256, size=(25, 100, 100, 3))
82-
merged = merge_images(images)
82+
merged = merge_images(images, k=2.5, threshold=0.3)
8383
assert merged.shape == (100, 100, 3)
8484
assert np.min(merged) >= 0
8585
assert np.max(merged) <= 255

0 commit comments

Comments
 (0)