|
113 | 113 | # Let's visualize the results |
114 | 114 |
|
115 | 115 | # %% |
116 | | - |
117 | 116 | import matplotlib.pyplot as plt |
118 | | - |
119 | | -fig, ax = plt.subplots(1, 3, figsize=(15, 5)) |
120 | | -ax[0].imshow(predict_cremi.in_data.array("r")[100], cmap="gray") |
121 | | -ax[0].set_title("Raw") |
122 | | -ax[1].imshow(predict_cremi.out_data[0].array("r")[:3, 100].transpose(1, 2, 0)) |
123 | | -ax[1].set_title("LSDs") |
124 | | -ax[2].imshow(predict_cremi.out_data[1].array("r")[3:6, 100].transpose(1, 2, 0)) |
125 | | -ax[2].set_title("Affinities") |
126 | | -plt.show() |
127 | | - |
128 | | -# %% [markdown] |
129 | | -# Now we can convert the results to a segmentation. We will run mutex watershed on the affinities in a multi-step process. |
130 | | -# 1) Local fragment extraction - This step runs blockwise and generates fragments from the affinities. For each fragment we save a node in a graph with attributes such as its spatial position and size. |
131 | | -# 2) Edge extraction - This step runs blockwise and computes mean affinities between fragments, adding edges to the fragment graph. |
132 | | -# 3) Graph Mutex Watershed - This step runs on the fragment graph, and creates a lookup table from fragment id -> segment id. |
133 | | -# 4) Relabel fragments - This step runs blockwise and creates the final segmentation. |
134 | | - |
135 | | -# %% |
136 | | -from volara.blockwise import AffAgglom, ExtractFrags, GraphMWS, Relabel |
137 | | -from volara.datasets import Labels |
138 | | -from volara.dbs import SQLite |
139 | | -from volara.lut import LUT |
140 | | - |
141 | | -# %% [markdown] |
142 | | -# First let's define the graph and arrays we are going to use. |
143 | | - |
144 | | -# because our graph is in an sql database, we need to define a schema with column names and types |
145 | | -# for node and edge attributes. |
146 | | -# For nodes: The defaults such as "id", "position", and "size" are already defined |
147 | | -# so we only need to define the additional attributes, in this case we have no additional node attributes. |
148 | | -# For edges: The defaults such as "id", "u", "v" are already defined, so we are only adding the additional |
149 | | -# attributes "xy_aff", "z_aff", and "lr_aff" for saving the mean affinities between fragments. |
150 | | - |
151 | | -# %% |
152 | | -fragments_graph = SQLite( |
153 | | - path="sample_A+_20160601.zarr/fragments.db", |
154 | | - edge_attrs={"xy_aff": "float", "z_aff": "float", "lr_aff": "float"}, |
155 | | -) |
156 | | -fragments_dataset = Labels(store="sample_A+_20160601.zarr/fragments") |
157 | | -segments_dataset = Labels(store="sample_A+_20160601.zarr/segments") |
158 | | - |
159 | | -# %% [markdown] |
160 | | -# Now we define the tasks with the parameters we want to use. |
161 | | - |
162 | | -# %% |
163 | | - |
164 | | -# Generate fragments in blocks |
165 | | -extract_frags = ExtractFrags( |
166 | | - db=fragments_graph, |
167 | | - affs_data=affs_dataset, |
168 | | - frags_data=fragments_dataset, |
169 | | - block_size=min_output_shape, |
170 | | - context=Coordinate(3, 6, 6) * 2, # A bit larger than the longest affinity |
171 | | - bias=[-0.6] + [-0.4] * 2 + [-0.6] * 2 + [-0.8] * 2, |
172 | | - strides=( |
173 | | - [Coordinate(1, 1, 1)] * 3 |
174 | | - + [Coordinate(1, 3, 3)] * 2 # We use larger strides for larger affinities |
175 | | - + [Coordinate(1, 6, 6)] * 2 # This is to avoid excessive splitting |
176 | | - ), |
177 | | - randomized_strides=True, # converts strides to probabilities of sampling affinities (1/prod(stride)) |
178 | | - remove_debris=64, # remove excessively small fragments |
179 | | - num_workers=4, |
180 | | -) |
181 | | - |
182 | | -# Generate agglomerated edge scores between fragments via mean affinity across all edges connecting two fragments |
183 | | -aff_agglom = AffAgglom( |
184 | | - db=fragments_graph, |
185 | | - affs_data=affs_dataset, |
186 | | - frags_data=fragments_dataset, |
187 | | - block_size=min_output_shape, |
188 | | - context=Coordinate(3, 6, 6) * 1, |
189 | | - scores={ |
190 | | - "z_aff": affs_dataset.neighborhood[0:1], |
191 | | - "xy_aff": affs_dataset.neighborhood[1:3], |
192 | | - "lr_aff": affs_dataset.neighborhood[3:], |
193 | | - }, |
194 | | - num_workers=4, |
195 | | -) |
196 | | - |
197 | | -# Run mutex watershed again, this time on the fragment graph with agglomerated edges |
198 | | -# instead of the voxel graph of affinities |
199 | | -lut = LUT(path="sample_A+_20160601.zarr/lut.npz") |
200 | | -total_roi = raw_dataset.array("r").roi |
201 | | -graph_mws = GraphMWS( |
202 | | - db=fragments_graph, |
203 | | - lut=lut, |
204 | | - weights={"xy_aff": (1, -0.4), "z_aff": (1, -0.6), "lr_aff": (1, -0.6)}, |
205 | | - roi=(total_roi.offset, total_roi.shape), |
206 | | -) |
207 | | - |
208 | | -# Relabel the fragments into segments |
209 | | -relabel = Relabel( |
210 | | - lut=lut, |
211 | | - frags_data=fragments_dataset, |
212 | | - seg_data=segments_dataset, |
213 | | - block_size=min_output_shape, |
214 | | - num_workers=4, |
215 | | -) |
216 | | - |
217 | | -pipeline = extract_frags + aff_agglom + graph_mws + relabel |
218 | | -pipeline.run_blockwise(multiprocessing=True) |
219 | | - |
220 | | -# %% [markdown] |
221 | | -# Let's visualize |
222 | | -# |
223 | | -# If you are following through on your own, I highly recommend installing `funlib.show.neuroglancer`, and |
224 | | -# running the command line tool via `neuroglancer -d sample_A+_20160601.zarr/*` to visualize the results in |
225 | | -# neuroglancer. |
226 | | -# |
227 | | -# For the purposes of visualizing here, we will make a simple gif |
228 | | - |
229 | | - |
230 | | -# %% |
231 | 117 | import matplotlib.animation as animation |
232 | | -import matplotlib.pyplot as plt |
233 | | -import numpy as np |
234 | | -from matplotlib.colors import ListedColormap |
235 | | - |
236 | | -fragments = fragments_dataset.array("r")[:, ::2, ::2] |
237 | | -segments = segments_dataset.array("r")[:, ::2, ::2] |
238 | | -raw = raw_dataset.array("r")[:, ::2, ::2] |
239 | | - |
240 | | -# Get unique labels |
241 | | -unique_labels = set(np.unique(fragments)) | set(np.unique(segments)) |
242 | | -num_labels = len(unique_labels) |
243 | | - |
244 | | - |
245 | | -def random_color(label): |
246 | | - rs = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(label))) |
247 | | - return np.array((rs.random(), rs.random(), rs.random())) |
248 | 118 |
|
249 | 119 |
|
250 | | -# Generate random colors for each label |
251 | | -random_fragment_colors = [random_color(label) for label in range(num_labels)] |
252 | | - |
253 | | -# Create a colormap |
254 | | -cmap_labels = ListedColormap(random_fragment_colors) |
255 | | - |
256 | | -# Map labels to indices for the colormap |
257 | | -label_to_index = {label: i for i, label in enumerate(unique_labels)} |
258 | | -indexed_fragments = np.vectorize(label_to_index.get)(fragments) |
259 | | -indexed_segments = np.vectorize(label_to_index.get)(segments) |
260 | | - |
261 | | -fig, axes = plt.subplots(1, 3, figsize=(18, 8)) |
| 120 | +fig, axes = plt.subplots(1, 3, figsize=(14, 8)) |
262 | 121 |
|
263 | 122 | ims = [] |
264 | | -for i, (raw_slice, fragments_slice, segments_slice) in enumerate( |
265 | | - zip(raw, indexed_fragments, indexed_segments) |
| 123 | +for i, (raw_slice, affs_slice, lsd_slice) in enumerate( |
| 124 | + zip( |
| 125 | + raw_dataset.array("r")[:], |
| 126 | + affs_dataset.array("r")[:].transpose([1, 0, 2, 3]), |
| 127 | + lsds_dataset.array("r")[:].transpose([1, 0, 2, 3]), |
| 128 | + ) |
266 | 129 | ): |
267 | 130 | # Show the raw data |
268 | 131 | if i == 0: |
269 | 132 | im_raw = axes[0].imshow(raw_slice, cmap="gray") |
270 | 133 | axes[0].set_title("Raw") |
271 | | - im_fragments = axes[1].imshow( |
272 | | - fragments_slice, |
273 | | - cmap=cmap_labels, |
| 134 | + im_affs_long = axes[1].imshow( |
| 135 | + affs_slice[[0, 5, 6]].transpose([1, 2, 0]), |
274 | 136 | vmin=0, |
275 | | - vmax=num_labels, |
| 137 | + vmax=255, |
276 | 138 | interpolation="none", |
277 | 139 | ) |
278 | | - axes[1].set_title("Fragments") |
279 | | - im_segments = axes[2].imshow( |
280 | | - segments_slice, |
281 | | - cmap=cmap_labels, |
| 140 | + axes[1].set_title("Affs (0, 5, 6)") |
| 141 | + im_lsd = axes[2].imshow( |
| 142 | + lsd_slice[:3].transpose([1, 2, 0]), |
282 | 143 | vmin=0, |
283 | | - vmax=num_labels, |
| 144 | + vmax=255, |
284 | 145 | interpolation="none", |
285 | 146 | ) |
286 | | - axes[2].set_title("Segments") |
| 147 | + axes[2].set_title("LSDs (0, 1, 2)") |
287 | 148 | else: |
288 | | - im_raw = axes[0].imshow(raw_slice, animated=True, cmap="gray") |
289 | | - im_fragments = axes[1].imshow( |
290 | | - fragments_slice, |
291 | | - cmap=cmap_labels, |
| 149 | + im_raw = axes[0].imshow(raw_slice, cmap="gray", animated=True) |
| 150 | + axes[0].set_title("Raw") |
| 151 | + im_affs_long = axes[1].imshow( |
| 152 | + affs_slice[[0, 5, 6]].transpose([1, 2, 0]), |
292 | 153 | vmin=0, |
293 | | - vmax=num_labels, |
| 154 | + vmax=255, |
294 | 155 | interpolation="none", |
295 | 156 | animated=True, |
296 | 157 | ) |
297 | | - im_segments = axes[2].imshow( |
298 | | - segments_slice, |
299 | | - cmap=cmap_labels, |
| 158 | + axes[1].set_title("Affs (0, 5, 6)") |
| 159 | + im_lsd = axes[2].imshow( |
| 160 | + lsd_slice[:3].transpose([1, 2, 0]), |
300 | 161 | vmin=0, |
301 | | - vmax=num_labels, |
| 162 | + vmax=255, |
302 | 163 | interpolation="none", |
303 | 164 | animated=True, |
304 | 165 | ) |
305 | | - ims.append([im_raw, im_fragments, im_segments]) |
| 166 | + axes[2].set_title("LSDs (0, 1, 2)") |
| 167 | + ims.append([im_raw, im_affs_long, im_lsd]) |
306 | 168 |
|
307 | 169 | ims = ims + ims[::-1] |
308 | 170 | ani = animation.ArtistAnimation(fig, ims, blit=True) |
309 | | -ani.save("_static/cremi/segmentation.gif", writer="pillow", fps=10) |
| 171 | +ani.save("_static/cremi/outputs.gif", writer="pillow", fps=10) |
310 | 172 | plt.close() |
311 | 173 |
|
312 | 174 | # %% [markdown] |
313 | | -# The final segmentation is shown below. Obviously this is not a great segmentation, but it is |
314 | | -# reasonably good for a model small enough to process a CREMI dataset in 20 minutes on a github |
315 | | -# action. |
316 | | -#  |
| 175 | +#  |
0 commit comments