Skip to content

Commit 4589015

Browse files
committed
visualize outputs
1 parent 1bdd458 commit 4589015

File tree

1 file changed

+28
-169
lines changed

1 file changed

+28
-169
lines changed

examples/cremi/cremi.py

Lines changed: 28 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -113,204 +113,63 @@
113113
# Let's visualize the results
114114

115115
# %%
116-
117116
import matplotlib.pyplot as plt
118-
119-
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
120-
ax[0].imshow(predict_cremi.in_data.array("r")[100], cmap="gray")
121-
ax[0].set_title("Raw")
122-
ax[1].imshow(predict_cremi.out_data[0].array("r")[:3, 100].transpose(1, 2, 0))
123-
ax[1].set_title("LSDs")
124-
ax[2].imshow(predict_cremi.out_data[1].array("r")[3:6, 100].transpose(1, 2, 0))
125-
ax[2].set_title("Affinities")
126-
plt.show()
127-
128-
# %% [markdown]
129-
# Now we can convert the results to a segmentation. We will run mutex watershed on the affinities in a multi step process.
130-
# 1) Local fragment extraction - This step runs blockwise and generates fragments from the affinities. For each fragment we save a node in a graph with attributes such as its spatial position and size.
131-
# 2) Edge extraction - This step runs blockwise and computes mean affinities between fragments, adding edges to the fragment graph.
132-
# 3) Graph Mutex Watershed - This step runs on the fragment graph, and creates a lookup table from fragment id -> segment id.
133-
# 4) Relabel fragments - This step runs blockwise and creates the final segmentation.
134-
135-
# %%
136-
from volara.blockwise import AffAgglom, ExtractFrags, GraphMWS, Relabel
137-
from volara.datasets import Labels
138-
from volara.dbs import SQLite
139-
from volara.lut import LUT
140-
141-
# %% [markdown]
142-
# First lets define the graph and arrays we are going to use.
143-
144-
# because our graph is in an sql database, we need to define a schema with column names and types
145-
# for node and edge attributes.
146-
# For nodes: The defaults such as "id", "position", and "size" are already defined
147-
# so we only need to define the additional attributes, in this case we have no additional node attributes.
148-
# For edges: The defaults such as "id", "u", "v" are already defined, so we are only adding the additional
149-
# attributes "xy_aff", "z_aff", and "lr_aff" for saving the mean affinities between fragments.
150-
151-
# %%
152-
fragments_graph = SQLite(
153-
path="sample_A+_20160601.zarr/fragments.db",
154-
edge_attrs={"xy_aff": "float", "z_aff": "float", "lr_aff": "float"},
155-
)
156-
fragments_dataset = Labels(store="sample_A+_20160601.zarr/fragments")
157-
segments_dataset = Labels(store="sample_A+_20160601.zarr/segments")
158-
159-
# %% [markdown]
160-
# Now we define the tasks with the parameters we want to use.
161-
162-
# %%
163-
164-
# Generate fragments in blocks
165-
extract_frags = ExtractFrags(
166-
db=fragments_graph,
167-
affs_data=affs_dataset,
168-
frags_data=fragments_dataset,
169-
block_size=min_output_shape,
170-
context=Coordinate(3, 6, 6) * 2, # A bit larger than the longest affinity
171-
bias=[-0.6] + [-0.4] * 2 + [-0.6] * 2 + [-0.8] * 2,
172-
strides=(
173-
[Coordinate(1, 1, 1)] * 3
174-
+ [Coordinate(1, 3, 3)] * 2 # We use larger strides for larger affinities
175-
+ [Coordinate(1, 6, 6)] * 2 # This is to avoid excessive splitting
176-
),
177-
randomized_strides=True, # converts strides to probabilities of sampling affinities (1/prod(stride))
178-
remove_debris=64, # remove excessively small fragments
179-
num_workers=4,
180-
)
181-
182-
# Generate agglomerated edge scores between fragments via mean affinity accross all edges connecting two fragments
183-
aff_agglom = AffAgglom(
184-
db=fragments_graph,
185-
affs_data=affs_dataset,
186-
frags_data=fragments_dataset,
187-
block_size=min_output_shape,
188-
context=Coordinate(3, 6, 6) * 1,
189-
scores={
190-
"z_aff": affs_dataset.neighborhood[0:1],
191-
"xy_aff": affs_dataset.neighborhood[1:3],
192-
"lr_aff": affs_dataset.neighborhood[3:],
193-
},
194-
num_workers=4,
195-
)
196-
197-
# Run mutex watershed again, this time on the fragment graph with agglomerated edges
198-
# instead of the voxel graph of affinities
199-
lut = LUT(path="sample_A+_20160601.zarr/lut.npz")
200-
total_roi = raw_dataset.array("r").roi
201-
graph_mws = GraphMWS(
202-
db=fragments_graph,
203-
lut=lut,
204-
weights={"xy_aff": (1, -0.4), "z_aff": (1, -0.6), "lr_aff": (1, -0.6)},
205-
roi=(total_roi.offset, total_roi.shape),
206-
)
207-
208-
# Relabel the fragments into segments
209-
relabel = Relabel(
210-
lut=lut,
211-
frags_data=fragments_dataset,
212-
seg_data=segments_dataset,
213-
block_size=min_output_shape,
214-
num_workers=4,
215-
)
216-
217-
pipeline = extract_frags + aff_agglom + graph_mws + relabel
218-
pipeline.run_blockwise(multiprocessing=True)
219-
220-
# %% [markdown]
221-
# Let's visualize
222-
#
223-
# If you are following through on your own, I highly recommend installing `funlib.show.neuroglancer`, and
224-
# running the command line tool via `neuroglancer -d sample_A+_20160601.zarr/*` to visualize the results in
225-
# neuroglancer.
226-
#
227-
# For the purposes of visualizing here, we will make a simple gif
228-
229-
230-
# %%
231117
import matplotlib.animation as animation
232-
import matplotlib.pyplot as plt
233-
import numpy as np
234-
from matplotlib.colors import ListedColormap
235-
236-
fragments = fragments_dataset.array("r")[:, ::2, ::2]
237-
segments = segments_dataset.array("r")[:, ::2, ::2]
238-
raw = raw_dataset.array("r")[:, ::2, ::2]
239-
240-
# Get unique labels
241-
unique_labels = set(np.unique(fragments)) | set(np.unique(segments))
242-
num_labels = len(unique_labels)
243-
244-
245-
def random_color(label):
246-
rs = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(label)))
247-
return np.array((rs.random(), rs.random(), rs.random()))
248118

249119

250-
# Generate random colors for each label
251-
random_fragment_colors = [random_color(label) for label in range(num_labels)]
252-
253-
# Create a colormap
254-
cmap_labels = ListedColormap(random_fragment_colors)
255-
256-
# Map labels to indices for the colormap
257-
label_to_index = {label: i for i, label in enumerate(unique_labels)}
258-
indexed_fragments = np.vectorize(label_to_index.get)(fragments)
259-
indexed_segments = np.vectorize(label_to_index.get)(segments)
260-
261-
fig, axes = plt.subplots(1, 3, figsize=(18, 8))
120+
fig, axes = plt.subplots(1, 3, figsize=(14, 8))
262121

263122
ims = []
264-
for i, (raw_slice, fragments_slice, segments_slice) in enumerate(
265-
zip(raw, indexed_fragments, indexed_segments)
123+
for i, (raw_slice, affs_slice, lsd_slice) in enumerate(
124+
zip(
125+
raw_dataset.array("r")[:],
126+
affs_dataset.array("r")[:].transpose([1, 0, 2, 3]),
127+
lsds_dataset.array("r")[:].transpose([1, 0, 2, 3]),
128+
)
266129
):
267130
# Show the raw data
268131
if i == 0:
269132
im_raw = axes[0].imshow(raw_slice, cmap="gray")
270133
axes[0].set_title("Raw")
271-
im_fragments = axes[1].imshow(
272-
fragments_slice,
273-
cmap=cmap_labels,
134+
im_affs_long = axes[1].imshow(
135+
affs_slice[[0, 5, 6]].transpose([1, 2, 0]),
274136
vmin=0,
275-
vmax=num_labels,
137+
vmax=255,
276138
interpolation="none",
277139
)
278-
axes[1].set_title("Fragments")
279-
im_segments = axes[2].imshow(
280-
segments_slice,
281-
cmap=cmap_labels,
140+
axes[1].set_title("Affs (0, 5, 6)")
141+
im_lsd = axes[2].imshow(
142+
lsd_slice[:3].transpose([1, 2, 0]),
282143
vmin=0,
283-
vmax=num_labels,
144+
vmax=255,
284145
interpolation="none",
285146
)
286-
axes[2].set_title("Segments")
147+
axes[2].set_title("LSDs (0, 1, 2)")
287148
else:
288-
im_raw = axes[0].imshow(raw_slice, animated=True, cmap="gray")
289-
im_fragments = axes[1].imshow(
290-
fragments_slice,
291-
cmap=cmap_labels,
149+
im_raw = axes[0].imshow(raw_slice, cmap="gray", animated=True)
150+
axes[0].set_title("Raw")
151+
im_affs_long = axes[1].imshow(
152+
affs_slice[[0, 5, 6]].transpose([1, 2, 0]),
292153
vmin=0,
293-
vmax=num_labels,
154+
vmax=255,
294155
interpolation="none",
295156
animated=True,
296157
)
297-
im_segments = axes[2].imshow(
298-
segments_slice,
299-
cmap=cmap_labels,
158+
axes[1].set_title("Affs (0, 5, 6)")
159+
im_lsd = axes[2].imshow(
160+
lsd_slice[:3].transpose([1, 2, 0]),
300161
vmin=0,
301-
vmax=num_labels,
162+
vmax=255,
302163
interpolation="none",
303164
animated=True,
304165
)
305-
ims.append([im_raw, im_fragments, im_segments])
166+
axes[2].set_title("LSDs (0, 1, 2)")
167+
ims.append([im_raw, im_affs_long, im_lsd])
306168

307169
ims = ims + ims[::-1]
308170
ani = animation.ArtistAnimation(fig, ims, blit=True)
309-
ani.save("_static/cremi/segmentation.gif", writer="pillow", fps=10)
171+
ani.save("_static/cremi/outputs.gif", writer="pillow", fps=10)
310172
plt.close()
311173

312174
# %% [markdown]
313-
# The final segmentation is shown below. Obviously this is not a great segmentation, but it is
314-
# reasonably good for a model small enough to process a CREMI dataset in 20 minutes on a github
315-
# action.
316-
# ![segmentation](_static/cremi/segmentation.gif)
175+
# ![segmentation](_static/cremi/outputs.gif)

0 commit comments

Comments
 (0)