Skip to content

Commit d046f3c

Browse files
committed
all tests implemented and passing for the illumination correction and computation components
1 parent 4ce0b50 commit d046f3c

File tree

3 files changed

+235
-272
lines changed

3 files changed

+235
-272
lines changed

illumifix/zarr_tools.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424

2525
CHANNEL_COUNT_KEY: Literal["channelCount"] = "channelCount"
2626

27+
ArrayLike: TypeAlias = np.ndarray | zarr.Array
2728
ChannelIndex: TypeAlias = int
29+
ChannelKey: TypeAlias = tuple["ChannelName", "WaveLenOpt", "WaveLenOpt"]
2830
ChannelName: TypeAlias = str
2931
Wavelength: TypeAlias = float
3032
WaveLenOpt: TypeAlias = Optional[Wavelength]
@@ -42,11 +44,11 @@ class ZarrAxis(Enum):
4244
Z = _ZarrAxis("z", "space")
4345
Y = _ZarrAxis("y", "space")
4446
X = _ZarrAxis("x", "space")
45-
47+
4648
@property
4749
def name(self) -> str:
4850
return self.value.name
49-
51+
5052
@property
5153
def typename(self) -> str:
5254
return self.value.typename
@@ -82,6 +84,9 @@ class ChannelMeta:
8284
emissionLambdaNm = attrs.field(validator=_check_null_or_positive_float) # type: WaveLenOpt
8385
excitationLambdaNm = attrs.field(validator=_check_null_or_positive_float) # type: WaveLenOpt
8486

87+
def get_lookup_key(self) -> ChannelKey:
88+
return self.name, self.emissionLambdaNm, self.excitationLambdaNm
89+
8590

8691
def _all_are_channels(_: "Channels", attr: attrs.Attribute, values: tuple[object]) -> None:
8792
"""Check that all values are channel metadata, throwing a TypeError otherwise."""
@@ -274,6 +279,7 @@ def parse_channels_from_zarr(
274279

275280

276281
def create_zgroup_file(*, root: Path) -> Path:
282+
"""Create the minimal content for the ZARR group file (format), in the correct place."""
277283
outpath = root / ".zgroup"
278284
with outpath.open(mode="x") as outfile:
279285
json.dump({"zarr_format": 2}, outfile, indent=4)
@@ -304,7 +310,7 @@ def to_axis_map(self) -> OrderedDict[ZarrAxis, int]:
304310
)
305311

306312
def get_axes_data(
307-
self, array: zarr.Array | np.ndarray, axis_value_pairs: Iterable[tuple[ZarrAxis, int]]
313+
self, array: ArrayLike, axis_value_pairs: Iterable[tuple[ZarrAxis, int]]
308314
) -> Result[np.ndarray, list[str]]:
309315
indexer: list[int | slice] = []
310316
requests: Mapping[str, int] = {ax.name: value for ax, value in axis_value_pairs}
@@ -324,7 +330,7 @@ def get_axes_data(
324330
return Result.Error(errors) if errors else Result.Ok(array[tuple(indexer)])
325331

326332
def get_axis_data(
327-
self, i: int, *, axis: ZarrAxis, array: zarr.Array | np.ndarray
333+
self, i: int, *, axis: ZarrAxis, array: ArrayLike
328334
) -> Result[np.ndarray, str]:
329335
if array.ndim != self.rank:
330336
return Result.Error(f"Array is rank {array.ndim}, but dimensions are rank {self.rank}")
@@ -340,7 +346,7 @@ def get_axis_data(
340346
return Result.Ok(array[indexer])
341347

342348
def get_channel_data(
343-
self, *, channel: int, array: zarr.Array | np.ndarray
349+
self, *, channel: int, array: ArrayLike
344350
) -> Result[np.ndarray, str]:
345351
return self.get_axis_data(channel, axis=ZarrAxis.C, array=array)
346352

@@ -354,7 +360,7 @@ class CanonicalImageDimensions(AxisMapping):
354360
x = attrs.field(validator=_CHECK_POSITIVE_INT) # type: int
355361

356362
def get_z_data(
357-
self, *, z_slice: int, array: zarr.Array | np.ndarray
363+
self, *, z_slice: int, array: ArrayLike
358364
) -> Result[np.ndarray, str]:
359365
return self.get_axis_data(z_slice, axis=ZarrAxis.Z, array=array)
360366

@@ -397,7 +403,7 @@ def parse_single_array_and_dimensions_from_zarr_group(
397403
(arr, CanonicalImageDimensions(**dict(zip(axis_names, dims, strict=True))))
398404
)
399405
except TypeError as e:
400-
return Result.Error(f"Could not build image dimensions: {e}")
406+
return Result.Error(f"Could not build image dimensions; error ({type(e).__name__}): {e}")
401407
case res:
402408
return res
403409

@@ -422,7 +428,6 @@ def compute_corrected_channels(
422428
weights_channels: Channels,
423429
) -> Result[list[np.ndarray], list[str]]:
424430
errors: list[str] = []
425-
by_ch: list[Result[np.ndarray, list[str]]] = []
426431
if image.ndim != image_dimensions.rank:
427432
errors.append(
428433
f"Image is of rank {image.ndim}, but dimensions are of rank {image_dimensions.rank}"
@@ -431,16 +436,29 @@ def compute_corrected_channels(
431436
errors.append(
432437
f"Weights are of rank {weights.ndim}, but dimensions are of rank {weights_dimensions.rank}"
433438
)
439+
if image_channels.count != image_dimensions.c:
440+
errors.append(f"Image channels count is {image_channels.count} but dimensions allege {image_dimensions.c}")
441+
if weights_channels.count != weights_dimensions.c:
442+
errors.append(f"Weights channels count is {weights_channels.count} but dimensions allege {weights_dimensions.c}")
434443
if errors:
435444
return Result.Error(errors)
436-
# TODO: implement in terms of image_channels and weights_channels.
437-
match _iterate_channels(dim_img=image_dimensions, dim_wts=weights_dimensions):
438-
case result.Result(tag="ok", ok=channels):
439-
for ch_img, ch_wts in channels:
445+
446+
channel_indices_in_weights: dict[tuple[ChannelName, WaveLenOpt, WaveLenOpt], int] = {
447+
ch.get_lookup_key(): i
448+
for i, ch in enumerate(weights_channels.values)
449+
}
450+
def index_channel_in_weights(ch: ChannelMeta) -> Result[int, str]:
451+
return Option.of_optional(channel_indices_in_weights.get(ch.get_lookup_key())).to_result(f"Channel not defined in weights: {ch}")
452+
453+
match traverse_accumulate_errors(lambda t: index_channel_in_weights(snd(t)).map(lambda i: (fst(t), i)))(enumerate(image_channels.values)):
454+
case result.Result(tag="ok", ok=index_pairs):
455+
errors: list[str] = []
456+
by_ch: list[Result[np.ndarray, list[str]]] = []
457+
for ch_img_index, ch_wts_index in index_pairs:
440458
match sequence_accumulate_errors(
441459
(
442-
image_dimensions.get_channel_data(channel=ch_img, array=image),
443-
weights_dimensions.get_channel_data(channel=ch_wts, array=weights),
460+
image_dimensions.get_channel_data(channel=ch_img_index, array=image),
461+
weights_dimensions.get_channel_data(channel=ch_wts_index, array=weights),
444462
)
445463
):
446464
case result.Result(tag="ok", ok=(img, wts)):
@@ -451,8 +469,10 @@ def compute_corrected_channels(
451469
)
452470
case result.Result(tag="error", error=messages):
453471
errors.append(
454-
f"For image channel {ch_img} and weights channel {ch_wts}, {len(errors)} error(s) extracting data: {', '.join(messages)}"
472+
f"For image channel {ch_img_index} and weights channel {ch_wts_index}, {len(errors)} error(s) extracting data: {', '.join(messages)}"
455473
)
456-
case result.Result(tag="error", error=err):
457-
return Result.Error([err])
458-
return Result.Error(errors) if errors else Result.Ok(by_ch)
474+
return Result.Error(errors) if errors else Result.Ok(by_ch)
475+
case result.Result(tag="error", error=messages):
476+
return Result.Error(messages)
477+
case unknown:
478+
raise TypeError(f"Expected an expression.Result-wrapped value but got a {type(unknown).__name__}")

scripts/compute_illumination_correction.py

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import logging
88
import re
99
import sys
10-
from pathlib import Path
10+
from pathlib import Path
1111
from typing import Callable, Iterable, Mapping, TypeAlias
1212

1313
import attrs
@@ -23,11 +23,10 @@
2323
from illumifix.zarr_tools import (
2424
CanonicalImageDimensions,
2525
ChannelIndex,
26+
ChannelKey,
2627
ChannelMeta,
27-
ChannelName,
2828
Channels,
2929
JsonEncoderForChannelMeta,
30-
WaveLenOpt,
3130
ZarrAxis,
3231
extract_single_channel_single_z_data,
3332
parse_channels_from_flattened_mapping_with_count,
@@ -183,6 +182,22 @@ def expand_target_folder(folder: ExtantFolder) -> Result[list[ExtantFolder], Pat
183182
return Result.Ok(goods)
184183

185184

185+
def get_sorted_channels_representation(channels: Channels) -> Channels:
186+
"""Order the metadata instances within the wrapper class."""
187+
ordered: list[ChannelMeta] = sorted(
188+
channels.values,
189+
key=lambda ch: (
190+
None if ch.emissionLambdaNm is None else -ch.emissionLambdaNm,
191+
ch.name,
192+
None if ch.excitationLambdaNm is None else -ch.excitationLambdaNm,
193+
),
194+
)
195+
return Channels(
196+
count=channels.count,
197+
values=tuple(attrs.evolve(ch, index=i) for i, ch in enumerate(ordered)),
198+
)
199+
200+
186201
def determine_channel_order(
187202
*, new_channels: Channels, ref_channels: Channels
188203
) -> Result[list[int], str]:
@@ -193,40 +208,16 @@ def determine_channel_order(
193208
return Result.Error(
194209
f"New channels' count is {new_channels.count}, but reference count is {ref_channels.count}"
195210
)
196-
ChannelKey: TypeAlias = tuple[ChannelName, WaveLenOpt, WaveLenOpt]
197-
get_ch_key: Callable[[ChannelMeta], ChannelKey] = lambda ch: (
198-
ch.name,
199-
ch.emissionLambdaNm,
200-
ch.excitationLambdaNm,
201-
)
202-
lookup: Mapping[ChannelKey, int] = {
203-
(ch.name, ch.emissionLambdaNm, ch.excitationLambdaNm): i
204-
for i, ch in enumerate(ref_channels.values)
205-
}
211+
212+
lookup: Mapping[ChannelKey, int] = {ch.get_lookup_key(): i for i, ch in enumerate(ref_channels.values)}
206213
get_ch_index: Callable[[ChannelMeta], Result[ChannelIndex, ChannelMeta]] = (
207-
lambda ch: Option.of_optional(lookup.get(get_ch_key(ch))).to_result(ch)
214+
lambda ch: Option.of_optional(lookup.get(ch.get_lookup_key())).to_result(ch)
208215
)
209216
return traverse_accumulate_errors(get_ch_index)(new_channels.values).map_error(
210217
lambda bad_chs: f"{len(bad_chs)} channel(s) cannot be resolved: {bad_chs}"
211218
)
212219

213220

214-
def get_sorted_channels_representation(channels: Channels) -> Channels:
215-
"""Order the metadata instances within the wrapper class."""
216-
ordered: list[ChannelMeta] = sorted(
217-
channels.values,
218-
key=lambda ch: (
219-
None if ch.emissionLambdaNm is None else -ch.emissionLambdaNm,
220-
ch.name,
221-
None if ch.excitationLambdaNm is None else -ch.excitationLambdaNm,
222-
),
223-
)
224-
return Channels(
225-
count=channels.count,
226-
values=tuple(attrs.evolve(ch, index=i) for i, ch in enumerate(ordered)),
227-
)
228-
229-
230221
def find_channel_rotations(
231222
path_arrdim_pairs: list[tuple[ExtantFolder, ArrayWithDims]],
232223
) -> Result[

0 commit comments

Comments
 (0)