
Commit f1478de

fix: more downsample fix
1 parent 68f6e80 commit f1478de

File tree

1 file changed: +33 -30 lines changed


examples/create_downampled.py

Lines changed: 33 additions & 30 deletions
@@ -23,6 +23,9 @@
 NUM_MIPS = 5
 MIP_CUTOFF = 3  # To save time you can start at the lowest resolution and work up
 NUM_CHANNELS = 2  # For less memory usage (can't be 1 right now though)
+NUM_ROWS = 3
+NUM_COLS = 6
+ALLOW_NON_ALIGNED_WRITE = True
 
 # %% Load the data
 OUTPUT_PATH.mkdir(exist_ok=True, parents=True)
@@ -44,16 +47,14 @@ def load_zarr_and_permute(file_path):
     return zarr_store, data
 
 
-def load_chunk_from_zarr_store(
-    zarr_store, x_start, x_end, y_start, y_end, z_start, z_end
-):
+def load_data_from_zarr_store(zarr_store):
     # Input is in Z, T, C, Y, X order
     data = zarr_store[
         :,  # T
-        z_start:z_end,  # Z
+        :,  # Z
         :NUM_CHANNELS,  # C
-        y_start:y_end,  # Y
-        x_start:x_end,  # X
+        :,  # Y
+        :,  # X
     ]
     # The original timestamp was 65535, can be filtered out
     data = np.where(data == 65535, 0, data)
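[Editor's note] The 65535 mask in the context lines above replaces the uint16 sentinel with zero before anything is written out. A minimal sketch of that masking step, using a made-up array in place of the loaded zarr slab:

import numpy as np

# Hypothetical stand-in for a loaded slab (uint16, sentinel value 65535)
data = np.array([[12, 65535], [7, 300]], dtype=np.uint16)

# Same masking as in the script: zero out the sentinel timestamp value
data = np.where(data == 65535, 0, data)
print(data)
# [[ 12   0]
#  [  7 300]]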
@@ -66,13 +67,6 @@ def load_chunk_from_zarr_store(
 
 zarr_store = load_zarr_data(all_files[0])
 
-# It may take too long to just load one file, might need to process in chunks
-# %% Check how long to load a single file
-# import time
-# start_time = time.time()
-# data = load_chunk_from_zarr_store(zarr_store, 0, 256, 0, 200, 0, 128)
-# print("Time to load a single file:", time.time() - start_time)
-
 # %% Inspect the data
 shape = zarr_store.shape
 # Input is in Z, T, C, Y, X order
@@ -87,17 +81,15 @@ def load_chunk_from_zarr_store(
 data_type = "uint16"
 chunk_size = [64, 64, 32]
 
-# You can provide a subset here also
-num_rows = 1
-num_cols = 1
 volume_size = [
-    single_file_dims_shape[0] * num_cols,
-    single_file_dims_shape[1] * num_rows,
+    single_file_dims_shape[0] * NUM_ROWS,
+    single_file_dims_shape[1] * NUM_COLS,
     single_file_dims_shape[2],
 ]  # XYZ (T)
 print("Volume size:", volume_size)
 
 # %% Setup the cloudvolume info
+# TODO verify if non-axis aligned is ok or not
 info = CloudVolume.create_new_info(
     num_channels=num_channels,
     layer_type="image",
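[Editor's note] The grid constants introduced at the top of the file now drive the output extent: the single-file footprint is tiled NUM_ROWS times along X and NUM_COLS times along Y. A quick check of the arithmetic, assuming (this shape is an assumption, not stated in the diff) a per-file XYZ extent of (1500, 936, 687), matching the read chunk_shape used later in the script:

NUM_ROWS, NUM_COLS = 3, 6
single_file_dims_shape = (1500, 936, 687)  # assumed per-file XYZ extent

volume_size = [
    single_file_dims_shape[0] * NUM_ROWS,  # 4500
    single_file_dims_shape[1] * NUM_COLS,  # 5616
    single_file_dims_shape[2],             # 687
]
print("Volume size:", volume_size)  # Volume size: [4500, 5616, 687]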
@@ -114,6 +106,8 @@ def load_chunk_from_zarr_store(
     "file://" + str(OUTPUT_PATH),
     info=info,
     mip=0,
+    non_aligned_writes=ALLOW_NON_ALIGNED_WRITE,
+    fill_missing=True,
 )
 vol.commit_info()
 vol.provenance.description = "Example data conversion"
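[Editor's note] non_aligned_writes matters here because a tile's bounds, especially after downsampling, need not land on the 64 x 64 x 32 chunk grid, and fill_missing=True lets reads of not-yet-written chunks return zeros instead of raising. A small sketch of the alignment condition being relaxed (ignoring the special case of writes that end exactly at the volume boundary):

import numpy as np

chunk_size = np.array([64, 64, 32])

def is_chunk_aligned(start, end):
    # Aligned when both corners of the write sit on the chunk grid
    start, end = np.asarray(start), np.asarray(end)
    return bool(np.all(start % chunk_size == 0) and np.all(end % chunk_size == 0))

print(is_chunk_aligned([0, 0, 0], [1500, 936, 687]))  # False: needs non_aligned_writes
print(is_chunk_aligned([0, 0, 0], [1536, 960, 704]))  # True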
@@ -122,21 +116,27 @@ def load_chunk_from_zarr_store(
 
 # %% Create the volumes for each mip level and hold progress
 vols = [
-    CloudVolume("file://" + str(OUTPUT_PATH), mip=i, compress=False)
+    CloudVolume(
+        "file://" + str(OUTPUT_PATH),
+        mip=i,
+        compress=False,
+        non_aligned_writes=ALLOW_NON_ALIGNED_WRITE,
+        fill_missing=True,
+    )
     for i in range(NUM_MIPS)
 ]
 progress_dir = OUTPUT_PATH / "progress"
 progress_dir.mkdir(exist_ok=True)
 
 # %% Functions for moving data
-shape = single_file_dims_shape
+shape = volume_size
 chunk_shape = np.array([1500, 936, 687])  # this is for reading data
 num_chunks_per_dim = np.ceil(shape / chunk_shape).astype(int)
 
 
 def process(args):
     x_i, y_i, z_i = args
-    flat_index = x_i * num_cols + y_i
+    flat_index = x_i * NUM_COLS + y_i
     print(f"Processing chunk {flat_index} at coordinates ({x_i}, {y_i}, {z_i})")
     # Load the data for this chunk
     loaded_zarr_store = load_zarr_data(all_files[flat_index])
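[Editor's note] flat_index = x_i * NUM_COLS + y_i is a row-major mapping from grid position to an index into all_files, so the file list must be ordered row by row with NUM_COLS entries per row. A quick sanity check of the mapping for the 3 x 6 grid:

NUM_ROWS, NUM_COLS = 3, 6

for x_i in range(NUM_ROWS):
    print([x_i * NUM_COLS + y_i for y_i in range(NUM_COLS)])
# [0, 1, 2, 3, 4, 5]
# [6, 7, 8, 9, 10, 11]
# [12, 13, 14, 15, 16, 17]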
@@ -150,26 +150,29 @@ def process(args):
     print(f"Processing chunk: {start} to {end}, file: {f_name}")
     if f_name.exists() and not OVERWRITE:
         return
-    rawdata = load_chunk_from_zarr_store(
-        loaded_zarr_store, start[0], end[0], start[1], end[1], start[2], end[2]
-    )
+    rawdata = load_data_from_zarr_store(loaded_zarr_store)
     for mip_level in reversed(range(MIP_CUTOFF, NUM_MIPS)):
         if mip_level == 0:
             downsampled = rawdata
             ds_start = start
             ds_end = end
         else:
-            ds_start = [int(math.ceil(s / (2**mip_level))) for s in start]
-            ds_end = [int(math.ceil(e / (2**mip_level))) for e in end]
+            factor = 2**mip_level
+            factor_tuple = (factor, factor, factor, 1)
+            ds_start = [int(np.round(s / (2**mip_level))) for s in start]
+            bounds_from_end = [int(math.ceil(e / (2**mip_level))) for e in end]
+            downsample_shape = [
+                int(math.ceil(s / f)) for s, f in zip(rawdata.shape, factor_tuple)
+            ]
+            ds_end_est = [s + d for s, d in zip(ds_start, downsample_shape)]
+            ds_end = [max(e1, e2) for e1, e2 in zip(ds_end_est, bounds_from_end)]
         print("DS fill", ds_start, ds_end)
-        downsampled = downsample_with_averaging(
-            rawdata, [2**mip_level, 2**mip_level, 2**mip_level, 1]
-        )
+        downsampled = downsample_with_averaging(rawdata, factor_tuple)
         print("Downsampled shape:", downsampled.shape)
 
         vols[mip_level][
             ds_start[0] : ds_end[0], ds_start[1] : ds_end[1], ds_start[2] : ds_end[2]
-        ] = downsampled.astype(np.uint16)
+        ] = downsampled
         touch(f_name)
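[Editor's note] This is the substantive fix: ds_end is no longer rounded up independently of ds_start; it is derived from ds_start plus the shape that downsample_with_averaging actually produces, then widened to the ceil-based bound when that is larger, so the destination slice always matches the downsampled block. A worked example of the mismatch the old bounds could hit, assuming a tile whose X range is 1500-3000 (grid position 1) at mip level 3:

import math
import numpy as np

factor = 2**3  # mip level 3
start, end = 1500, 3000  # assumed X bounds of the second tile
raw_extent = end - start  # 1500 source voxels

# Old bounds: both corners rounded up independently
old_start = int(math.ceil(start / factor))  # 188
old_end = int(math.ceil(end / factor))      # 375 -> destination extent 187

# Extent actually produced by averaging 1500 voxels down by 8
down_extent = int(math.ceil(raw_extent / factor))  # 188

print(old_end - old_start, down_extent)  # 187 188 -> shape mismatch on write

# New bounds: end derived from start + produced shape, widened if needed
new_start = int(np.round(start / factor))  # 188
new_end = max(new_start + down_extent, int(math.ceil(end / factor)))  # 376
print(new_end - new_start, down_extent)  # 188 188 -> slice matches block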