Skip to content

Commit 39b1af9

Browse files
authored
updates
1 parent 7666940 commit 39b1af9

File tree

2 files changed

+234
-91
lines changed

2 files changed

+234
-91
lines changed

src/aind_exaspim_image_compression/utils/img_util.py

Lines changed: 141 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
"""
1010

1111
from bm4d import bm4d
12+
from concurrent.futures import ThreadPoolExecutor
13+
from itertools import product
1214
from numcodecs import Blosc
1315
from ome_zarr.writer import write_multiscale
1416
from scipy.ndimage import uniform_filter
@@ -69,36 +71,10 @@ def _read_tiff(img_path, storage_options=None):
6971

7072

7173
def _is_gcs_path(path):
72-
"""
73-
Checks whether image is stored in a GCS bucket.
74-
75-
Parameters
76-
----------
77-
img_path : str
78-
Path to image.
79-
80-
Returns
81-
-------
82-
bool
83-
Indication of whether image is stored in a GCS bucket.
84-
"""
8574
return path.startswith("gs://")
8675

8776

8877
def _is_s3_path(path):
89-
"""
90-
Checks whether image is stored in an S3 bucket.
91-
92-
Parameters
93-
----------
94-
img_path : str
95-
Path to image.
96-
97-
Returns
98-
-------
99-
bool
100-
Indication of whether image is stored in a S3 bucket.
101-
"""
10278
return path.startswith("s3://")
10379

10480

@@ -271,9 +247,9 @@ def local_to_physical(local_voxel, offset, multiscale):
271247
return to_physical(global_voxel, multiscale)
272248

273249

274-
# --- Custom Classes ---
250+
# --- Compression utils ---
275251
class BM4D:
276-
def __init__(self, sigma=10):
252+
def __init__(self, sigma=100):
277253
self.sigma = sigma
278254

279255
def __call__(self, noise):
@@ -283,6 +259,112 @@ def __call__(self, noise):
283259
return (noise - mn) / mx, (denoised - mn) / mx, (mn, mx)
284260

285261

262+
def compute_cratio(img, codec, chunk_shape=(64, 64, 64)):
    """
    Computes a Zarr-style chunked compression ratio for a given image.

    Parameters
    ----------
    img : np.ndarray
        Image to compute compression ratio of. May have any number of
        dimensions; axes beyond len(chunk_shape) are kept whole within
        each chunk.
    codec : blosc.Blosc
        Codec used to compress each chunk (any object exposing an
        "encode(ndarray) -> bytes" method works).
    chunk_shape : Tuple[int]
        Shape of chunks Zarr would use. Default is (64, 64, 64).

    Returns
    -------
    float
        Compression ratio = total uncompressed size / total compressed
        size, rounded to 2 decimal places.
    """
    # Cast once so every chunk is encoded with a consistent dtype/layout.
    img = np.ascontiguousarray(img, dtype=np.uint16)

    total_compressed = 0
    total_uncompressed = 0

    # Iterate chunk origins over however many axes the image actually has.
    # The previous hand-rolled 3-deep loop broke on 2D input: the fallback
    # branch built a slice whose stop was itself a slice object, and a 2D
    # array was indexed with three indices.
    starts = [range(0, s, c) for s, c in zip(img.shape, chunk_shape)]
    for origin in product(*starts):
        window = tuple(slice(i, i + c) for i, c in zip(origin, chunk_shape))
        chunk = np.ascontiguousarray(img[window])
        total_compressed += len(codec.encode(chunk))
        total_uncompressed += chunk.nbytes
    return round(total_uncompressed / total_compressed, 2)
298+
299+
300+
def compute_cratio_jpegxl(img, codec, chunk_shape=(128, 128, 64), max_workers=32):
    """
    Computes a chunked compression ratio where each chunk is encoded as a
    stack of 2D slices along the last axis (the way a JPEG-XL codec, which
    operates on 2D planes, would consume it).

    Parameters
    ----------
    img : np.ndarray
        Image to compute compression ratio of.
    codec : object
        Codec exposing an "encode(ndarray) -> bytes" method, applied to
        each 2D slice of each chunk.
    chunk_shape : Tuple[int]
        Shape of chunks to split the image into. Default is (128, 128, 64).
    max_workers : int
        Number of threads used to encode chunks concurrently. Default is 32.

    Returns
    -------
    float
        Total uncompressed bytes / total compressed bytes, rounded to 2
        decimal places.
    """
    img = np.ascontiguousarray(img)

    def encode_chunk(origin):
        # Clamp the window so edge chunks stay within the image.
        window = tuple(
            slice(start, min(start + size, dim))
            for start, size, dim in zip(origin, chunk_shape, img.shape)
        )
        block = img[window]
        packed = sum(
            len(codec.encode(np.ascontiguousarray(block[..., plane])))
            for plane in range(block.shape[-1])
        )
        return block.nbytes, packed

    origins = list(
        product(*(range(0, dim, size) for dim, size in zip(img.shape, chunk_shape)))
    )
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(executor.map(encode_chunk, origins))

    raw_total = sum(raw for raw, _ in results)
    packed_total = sum(packed for _, packed in results)
    return round(raw_total / packed_total, 2)
327+
328+
329+
def compress_and_decompress_jpeg(img, codec, chunk_shape=(128, 128, 64), max_workers=32):
    """
    Round-trips an image through slice-wise compression and decompression,
    returning the reconstructed image alongside the compression ratio.

    Each chunk is encoded/decoded one 2D plane at a time along the last
    axis, then reassembled into the output array.

    Parameters
    ----------
    img : np.ndarray
        Image to compress and reconstruct.
    codec : object
        Codec exposing "encode(ndarray) -> bytes" and "decode(bytes)"
        methods; decode must return a 2D array matching the encoded plane.
    chunk_shape : Tuple[int]
        Shape of chunks to split the image into. Default is (128, 128, 64).
    max_workers : int
        Number of threads used to process chunks concurrently. Default is 32.

    Returns
    -------
    Tuple[np.ndarray, float]
        Reconstructed image and compression ratio (uncompressed bytes /
        compressed bytes, rounded to 2 decimal places).
    """
    img = np.ascontiguousarray(img)
    output = np.empty_like(img)

    def roundtrip(origin):
        # Clamp the window so edge chunks stay within the image.
        window = tuple(
            slice(start, min(start + size, dim))
            for start, size, dim in zip(origin, chunk_shape, img.shape)
        )
        block = img[window]

        packed = 0
        planes = []
        for j in range(block.shape[-1]):
            payload = codec.encode(np.ascontiguousarray(block[..., j]))
            packed += len(payload)
            planes.append(codec.decode(payload))
        return window, block.nbytes, packed, np.stack(planes, axis=-1)

    origins = list(
        product(*(range(0, dim, size) for dim, size in zip(img.shape, chunk_shape)))
    )

    raw_total = 0
    packed_total = 0
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Writes to "output" happen here in the consuming thread, over
        # disjoint windows, so no synchronization is needed.
        for window, raw, packed, block in executor.map(roundtrip, origins):
            output[window] = block
            raw_total += raw
            packed_total += packed

    return output, round(raw_total / packed_total, 2)
366+
367+
286368
# --- Visualizations ---
287369
def plot_mips(img, output_path=None, vmax=None):
288370
"""
@@ -419,33 +501,39 @@ def convert_tiff_ome_zarr(
419501
-------
420502
None
421503
"""
422-
# Open image
423-
im = tifffile.imread(in_path)
424-
while im.ndim < 5:
425-
im = im[np.newaxis, ...]
426-
427-
# Initializations
428-
pyramid = multiscale(im, windowed_mode, scale_factors=[1, 1, 2, 2, 2])[
429-
:n_levels
430-
]
504+
img = tifffile.imread(in_path)
505+
write_ome_zarr(img, out_path, chunks, compressor, voxel, n_levels)
506+
507+
508+
def write_ome_zarr(
509+
img,
510+
out_path,
511+
chunks: tuple = (1, 1, 64, 128, 128),
512+
compressor: Any = Blosc(cname="zstd", clevel=5, shuffle=Blosc.SHUFFLE),
513+
voxel_size: tuple = (748, 748, 1000),
514+
n_levels: int = 3,
515+
):
516+
# Ensure 5D image (T, C, Z, Y, X)
517+
while img.ndim < 5:
518+
img = img[np.newaxis, ...]
519+
520+
# Generate multiscale pyramid
521+
pyramid = multiscale(img, windowed_mode, scale_factors=[1, 1, 2, 2, 2])[:n_levels]
431522
pyramid = [level.data for level in pyramid]
432-
z = zarr.open(
433-
store=zarr.DirectoryStore(out_path, dimension_separator="/"), mode="w"
434-
)
435-
voxel_size = np.array([1, 1] + list(reversed(voxel_size)))
436-
scales = [
437-
np.concatenate((voxel_size[:2], voxel_size[2:] * 2**i))
438-
for i in range(n_levels)
439-
]
440-
coordinate_transformations = [
441-
[{"type": "scale", "scale": scale.tolist()}] for scale in scales
442-
]
443-
storage_options = {"compressor": compressor}
444-
445-
# Write image
523+
524+
# Prepare Zarr store
525+
store = zarr.DirectoryStore(out_path, dimension_separator="/")
526+
zgroup = zarr.open(store=store, mode="w")
527+
528+
# Voxel size scaling for each level
529+
base_scale = np.array([1, 1, *reversed(voxel_size)])
530+
scales = [base_scale[:2].tolist() + (base_scale[2:] * 2**i).tolist() for i in range(n_levels)]
531+
coordinate_transformations = [[{"type": "scale", "scale": s}] for s in scales]
532+
533+
# Write to OME-Zarr
446534
write_multiscale(
447535
pyramid=pyramid,
448-
group=z,
536+
group=zgroup,
449537
chunks=chunks,
450538
axes=[
451539
{"name": "t", "type": "time", "unit": "millisecond"},
@@ -455,48 +543,10 @@ def convert_tiff_ome_zarr(
455543
{"name": "x", "type": "space", "unit": "micrometer"},
456544
],
457545
coordinate_transformations=coordinate_transformations,
458-
storage_options=storage_options,
546+
storage_options={"compressor": compressor},
459547
)
460548

461549

462-
def compute_cratio(img, codec, chunk_shape=(64, 64, 64)):
463-
"""
464-
Computes a Zarr-style chunked compression ratio for a given image.
465-
466-
Parameters
467-
----------
468-
img : np.ndarray
469-
Image to compute compression ratio of.
470-
codec : blosc.Blosc
471-
Blosc codec used to compress each chunk.
472-
chunk_shape : Tuple[int]
473-
Shape of chunks Zarr would use. Default is (64, 64, 64).
474-
475-
Returns
476-
-------
477-
float
478-
Compression ratio = total uncompressed size / total compressed size.
479-
"""
480-
img = np.ascontiguousarray(img, dtype=np.uint16)
481-
total_compressed_size = 0
482-
total_uncompressed_size = 0
483-
484-
z = [range(0, s, c) for s, c in zip(img.shape, chunk_shape)]
485-
for z0 in z[0]:
486-
for z1 in z[1]:
487-
for z2 in z[2] if len(z) > 2 else [0]:
488-
slice_ = img[
489-
z0: z0 + chunk_shape[0],
490-
z1: z1 + chunk_shape[1],
491-
z2: z2 + chunk_shape[2] if len(z) > 2 else slice(None),
492-
]
493-
chunk = np.ascontiguousarray(slice_)
494-
compressed = codec.encode(chunk)
495-
total_compressed_size += len(compressed)
496-
total_uncompressed_size += chunk.nbytes
497-
return round(total_uncompressed_size / total_compressed_size, 2)
498-
499-
500550
def compute_mae(img1, img2):
501551
"""
502552
Computes the mean absolute difference between two 3D images.
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""
2+
Created on Mon Nov 25 14:00:00 2024
3+
4+
@author: Anna Grim
5+
@email: anna.grim@alleninstitute.org
6+
7+
Helper routines for working with SmartSheets
8+
9+
"""
10+
11+
from datetime import datetime
12+
13+
import ast
14+
import pandas as pd
15+
16+
import smartsheet
17+
18+
19+
class SmartSheetClient:
    """
    Thin wrapper around the SmartSheet SDK that binds a client to a single
    sheet (looked up by name at construction time) and exposes lookup and
    export helpers.
    """

    def __init__(self, access_token, sheet_name):
        """
        Connects to SmartSheet and loads the sheet named "sheet_name".

        Parameters
        ----------
        access_token : str
            SmartSheet API access token.
        sheet_name : str
            Name of the sheet to bind this client to.

        Raises
        ------
        Exception
            If no sheet with the given name exists.
        """
        # Instance attributes
        self.client = smartsheet.Smartsheet(access_token)
        self.sheet_name = sheet_name
        self.sheet_id = self.find_sheet_id()
        self.sheet = self.client.Sheets.get_sheet(self.sheet_id)

    # --- Lookup Routines ---
    def find_sheet_id(self):
        """
        Finds the ID of the sheet whose name equals "self.sheet_name".

        Returns
        -------
        The sheet's ID as reported by the SmartSheet API.

        Raises
        ------
        Exception
            If no sheet with that name is listed for this account.
        """
        response = self.client.Sheets.list_sheets()
        for sheet in response.data:
            if sheet.name == self.sheet_name:
                return sheet.id
        raise Exception(f"Sheet Not Found - sheet_name={self.sheet_name}")

    def find_row_id(self, keyword):
        """
        Finds the ID of the first row containing a cell whose displayed
        text equals "keyword".

        Parameters
        ----------
        keyword : str
            Display value to search for.

        Raises
        ------
        Exception
            If no cell in the sheet displays the keyword.
        """
        for row in self.sheet.rows:
            for cell in row.cells:
                if cell.display_value == keyword:
                    return row.id
        raise Exception(f"Row Not Found - keyword={keyword}")

    # --- Miscellaneous ---
    def get_dataframe(self):
        """
        Converts the bound sheet into a pandas DataFrame whose columns are
        the sheet's column titles and whose rows mirror the sheet's rows.
        """
        # Extract column titles
        columns = [col.title for col in self.sheet.columns]

        # Extract row data
        data = []
        for row in self.sheet.rows:
            row_data = []
            for cell in row.cells:
                # NOTE(review): when display_value is falsy this stores the
                # (falsy) display_value instead of cell.value — the two
                # attributes look swapped relative to the obvious intent
                # ("prefer display_value, fall back to value"); confirm.
                val = cell.value if cell.display_value else cell.display_value
                row_data.append(val)
            data.append(row_data)
        return pd.DataFrame(data, columns=columns)

    def update_rows(self, updated_row):
        """
        Pushes a single updated row object back to the bound sheet.
        """
        self.client.Sheets.update_rows(self.sheet_id, [updated_row])
59+
60+
61+
# --- Neuron Reconstruction Utils ---
62+
def extract_somas(df):
63+
idx = 0
64+
soma_locations = dict()
65+
while idx < len(df["Horta Coordinates"]):
66+
microscope = df["Horta Coordinates"][idx]
67+
if type(microscope) is str:
68+
if "spim" in microscope.lower():
69+
brain_id = str(df["ID"][idx]).split(".")[0]
70+
xyz_list = extract_somas_by_brain(df, idx + 1)
71+
if len(xyz_list) > 0:
72+
soma_locations[brain_id] = xyz_list
73+
idx += 1
74+
return soma_locations
75+
76+
77+
def extract_somas_by_brain(df, idx):
    """
    Collects the soma coordinates listed under a single brain's header row.

    Starting at row "idx", reads consecutive string entries of the
    "Horta Coordinates" column and parses those that look like coordinate
    lists (contain both "[" and "]"). Stops at the first non-string entry
    or the end of the column.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe with a "Horta Coordinates" column.
    idx : int
        Row index to start reading coordinates from.

    Returns
    -------
    list
        Parsed coordinate lists, e.g. [[x, y, z], ...].
    """
    coords = df["Horta Coordinates"]
    xyz_list = list()
    # Bounds-check before each access so a header on the last row does not
    # raise (the original only checked after the first access).
    while idx < len(coords) and isinstance(coords[idx], str):
        entry = coords[idx]
        if "[" in entry and "]" in entry:
            try:
                xyz_list.append(ast.literal_eval(entry))
            except (ValueError, SyntaxError):
                # Malformed coordinate strings are skipped, not fatal.
                # (Narrowed from a bare "except:" that hid every error.)
                pass
        idx += 1
    return xyz_list

0 commit comments

Comments
 (0)