Skip to content

Commit df5b2e2

Browse files
committed
feat: more progress on downsample
1 parent 00813c5 commit df5b2e2

File tree

1 file changed

+31
-23
lines changed

1 file changed

+31
-23
lines changed

examples/create_downampled.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
# Other settings
2222
OUTPUT_PATH.mkdir(exist_ok=True, parents=True)
2323
OVERWRITE = False
24-
NUM_MIPS = 5
25-
MIP_CUTOFF = 4 # To save time you can start at the lowest resolution and work up
24+
NUM_MIPS = 3
25+
MIP_CUTOFF = 0 # To save time you can start at the lowest resolution and work up
2626

2727
# %% Load the data
2828
OUTPUT_PATH.mkdir(exist_ok=True, parents=True)
@@ -57,7 +57,6 @@ def load_chunk_from_zarr_store(
5757
]
5858
# The original timestamp was 65535, can be filtered out
5959
data = np.where(data == 65535, 0, data)
60-
print("Loaded data shape b4 sq:", data.shape)
6160
data = np.squeeze(data) # Remove any singleton dimensions
6261
print("Loaded data shape:", data.shape)
6362
# Then we permute to XYTCZ
@@ -70,9 +69,7 @@ def load_chunk_from_zarr_store(
7069
# It may take too long to just load one file, might need to process in chunks
7170
# %% Check how long to load a single file
7271
start_time = time.time()
73-
data = load_chunk_from_zarr_store(
74-
zarr_store, 0, 256, 0, 256, 0, 128, channel=0
75-
)
72+
data = load_chunk_from_zarr_store(zarr_store, 0, 256, 0, 200, 0, 128, channel=0)
7673
print("Time to load a single file:", time.time() - start_time)
7774

7875
# %% Inspect the data
@@ -90,8 +87,8 @@ def load_chunk_from_zarr_store(
9087
chunk_size = [256, 256, 128]
9188

9289
# You can provide a subset here also
93-
num_rows = 16
94-
num_cols = 24
90+
num_rows = 1
91+
num_cols = 1
9592
volume_size = [
9693
single_file_dims_shape[0] * num_cols,
9794
single_file_dims_shape[1] * num_rows,
@@ -131,35 +128,42 @@ def load_chunk_from_zarr_store(
131128
progress_dir.mkdir(exist_ok=True)
132129

133130
# %% Functions for moving data
134-
read_shape = single_file_dims_shape # this is for reading data
131+
# TODO setup file loop
132+
# TODO setup channel handling
133+
134+
shape = single_file_dims_shape
135+
chunk_shape = np.array([1500, 936, 687]) # this is for reading data
136+
num_chunks_per_dim = np.ceil(shape / chunk_shape).astype(int)
135137

136138

137139
def process(args):
138-
x_i, y_i = args
139-
start = [x_i * read_shape[0], y_i * read_shape[1], 0]
140+
x_i, y_i, z_i = args
141+
start = [x_i * chunk_shape[0], y_i * chunk_shape[1], z_i * chunk_shape[2]]
140142
end = [
141-
(x_i + 1) * read_shape[0],
142-
(y_i + 1) * read_shape[1],
143-
read_shape[2],
143+
min((x_i + 1) * chunk_shape[0], shape[0]),
144+
min((y_i + 1) * chunk_shape[1], shape[1]),
145+
min((z_i + 1) * chunk_shape[2], shape[2]),
144146
]
145147
f_name = progress_dir / f"{start[0]}-{end[0]}_{start[1]}-{end[1]}.done"
148+
print(f"Processing chunk: {start} to {end}, file: {f_name}")
146149
if f_name.exists() and not OVERWRITE:
147150
return
148-
flat_index = x_i * num_cols + y_i
149-
path = all_files[flat_index]
150-
rawdata = load_zarr_and_permute(path)[1]
151-
print("Working on", f_name)
151+
rawdata = load_chunk_from_zarr_store(
152+
zarr_store, start[0], end[0], start[1], end[1], start[2], end[2], channel=0
153+
)
152154
for mip_level in reversed(range(MIP_CUTOFF, NUM_MIPS)):
153155
if mip_level == 0:
154156
downsampled = rawdata
155157
ds_start = start
156158
ds_end = end
157159
else:
158160
downsampled = downsample_with_averaging(
159-
rawdata, [2 * mip_level, 2 * mip_level, 2 * mip_level, 1]
161+
rawdata, [2 * mip_level, 2 * mip_level, 2 * mip_level]
160162
)
161163
ds_start = [int(math.ceil(s / (2 * mip_level))) for s in start]
162164
ds_end = [int(math.ceil(e / (2 * mip_level))) for e in end]
165+
print(ds_start, ds_end)
166+
print("Downsampled shape:", downsampled.shape)
163167

164168
vols[mip_level][
165169
ds_start[0] : ds_end[0], ds_start[1] : ds_end[1], ds_start[2] : ds_end[2]
@@ -168,12 +172,16 @@ def process(args):
168172

169173

170174
# %% Try with a single chunk to see if it works
171-
x_i, y_i = 0, 0
172-
process((x_i, y_i))
175+
x_i, y_i, z_i = 0, 0, 0
176+
process((x_i, y_i, z_i))
173177

174-
# %% Loop over all the chunks
175178

176-
coords = itertools.product(range(num_rows), range(num_cols))
179+
# %% Loop over all the chunks
180+
coords = itertools.product(
181+
range(num_chunks_per_dim[0]),
182+
range(num_chunks_per_dim[1]),
183+
range(num_chunks_per_dim[2]),
184+
)
177185
# Do it in reverse order because the last chunks are most likely to error
178186
reversed_coords = list(coords)
179187
reversed_coords.reverse()

0 commit comments

Comments (0)