2424
2525CHANNEL_COUNT_KEY : Literal ["channelCount" ] = "channelCount"
2626
27+ ArrayLike : TypeAlias = np .ndarray | zarr .Array
2728ChannelIndex : TypeAlias = int
29+ ChannelKey : TypeAlias = tuple ["ChannelName" , "WaveLenOpt" , "WaveLenOpt" ]
2830ChannelName : TypeAlias = str
2931Wavelength : TypeAlias = float
3032WaveLenOpt : TypeAlias = Optional [Wavelength ]
@@ -42,11 +44,11 @@ class ZarrAxis(Enum):
4244 Z = _ZarrAxis ("z" , "space" )
4345 Y = _ZarrAxis ("y" , "space" )
4446 X = _ZarrAxis ("x" , "space" )
45-
47+
4648 @property
4749 def name (self ) -> str :
4850 return self .value .name
49-
51+
5052 @property
5153 def typename (self ) -> str :
5254 return self .value .typename
@@ -82,6 +84,9 @@ class ChannelMeta:
8284 emissionLambdaNm = attrs .field (validator = _check_null_or_positive_float ) # type: WaveLenOpt
8385 excitationLambdaNm = attrs .field (validator = _check_null_or_positive_float ) # type: WaveLenOpt
8486
87+ def get_lookup_key (self ) -> ChannelKey :
88+ return self .name , self .emissionLambdaNm , self .excitationLambdaNm
89+
8590
8691def _all_are_channels (_ : "Channels" , attr : attrs .Attribute , values : tuple [object ]) -> None :
8792 """Check that all values are channel metadata, throwing a TypeError otherwise."""
@@ -274,6 +279,7 @@ def parse_channels_from_zarr(
274279
275280
276281def create_zgroup_file (* , root : Path ) -> Path :
282+ """Create the minimal content for the ZARR group file (format), in the correct place."""
277283 outpath = root / ".zgroup"
278284 with outpath .open (mode = "x" ) as outfile :
279285 json .dump ({"zarr_format" : 2 }, outfile , indent = 4 )
@@ -304,7 +310,7 @@ def to_axis_map(self) -> OrderedDict[ZarrAxis, int]:
304310 )
305311
306312 def get_axes_data (
307- self , array : zarr . Array | np . ndarray , axis_value_pairs : Iterable [tuple [ZarrAxis , int ]]
313+ self , array : ArrayLike , axis_value_pairs : Iterable [tuple [ZarrAxis , int ]]
308314 ) -> Result [np .ndarray , list [str ]]:
309315 indexer : list [int | slice ] = []
310316 requests : Mapping [str , int ] = {ax .name : value for ax , value in axis_value_pairs }
@@ -324,7 +330,7 @@ def get_axes_data(
324330 return Result .Error (errors ) if errors else Result .Ok (array [tuple (indexer )])
325331
326332 def get_axis_data (
327- self , i : int , * , axis : ZarrAxis , array : zarr . Array | np . ndarray
333+ self , i : int , * , axis : ZarrAxis , array : ArrayLike
328334 ) -> Result [np .ndarray , str ]:
329335 if array .ndim != self .rank :
330336 return Result .Error (f"Array is rank { array .ndim } , but dimensions are rank { self .rank } " )
@@ -340,7 +346,7 @@ def get_axis_data(
340346 return Result .Ok (array [indexer ])
341347
342348 def get_channel_data (
343- self , * , channel : int , array : zarr . Array | np . ndarray
349+ self , * , channel : int , array : ArrayLike
344350 ) -> Result [np .ndarray , str ]:
345351 return self .get_axis_data (channel , axis = ZarrAxis .C , array = array )
346352
@@ -354,7 +360,7 @@ class CanonicalImageDimensions(AxisMapping):
354360 x = attrs .field (validator = _CHECK_POSITIVE_INT ) # type: int
355361
356362 def get_z_data (
357- self , * , z_slice : int , array : zarr . Array | np . ndarray
363+ self , * , z_slice : int , array : ArrayLike
358364 ) -> Result [np .ndarray , str ]:
359365 return self .get_axis_data (z_slice , axis = ZarrAxis .Z , array = array )
360366
@@ -397,7 +403,7 @@ def parse_single_array_and_dimensions_from_zarr_group(
397403 (arr , CanonicalImageDimensions (** dict (zip (axis_names , dims , strict = True ))))
398404 )
399405 except TypeError as e :
400- return Result .Error (f"Could not build image dimensions: { e } " )
406+ return Result .Error (f"Could not build image dimensions; error ( { type ( e ). __name__ } ) : { e } " )
401407 case res :
402408 return res
403409
@@ -422,7 +428,6 @@ def compute_corrected_channels(
422428 weights_channels : Channels ,
423429) -> Result [list [np .ndarray ], list [str ]]:
424430 errors : list [str ] = []
425- by_ch : list [Result [np .ndarray , list [str ]]] = []
426431 if image .ndim != image_dimensions .rank :
427432 errors .append (
428433 f"Image is of rank { image .ndim } , but dimensions are of rank { image_dimensions .rank } "
@@ -431,16 +436,29 @@ def compute_corrected_channels(
431436 errors .append (
432437 f"Weights are of rank { weights .ndim } , but dimensions are of rank { weights_dimensions .rank } "
433438 )
439+ if image_channels .count != image_dimensions .c :
440+ errors .append (f"Image channels count is { image_channels .count } but dimensions allege { image_dimensions .c } " )
441+ if weights_channels .count != weights_dimensions .c :
442+ errors .append (f"Weights channels count is { weights_channels .count } but dimensions allege { weights_dimensions .c } " )
434443 if errors :
435444 return Result .Error (errors )
436- # TODO: implement in terms of image_channels and weights_channels.
437- match _iterate_channels (dim_img = image_dimensions , dim_wts = weights_dimensions ):
438- case result .Result (tag = "ok" , ok = channels ):
439- for ch_img , ch_wts in channels :
445+
446+ channel_indices_in_weights : dict [tuple [ChannelName , WaveLenOpt , WaveLenOpt ], int ] = {
447+ ch .get_lookup_key (): i
448+ for i , ch in enumerate (weights_channels .values )
449+ }
450+ def index_channel_in_weights (ch : ChannelMeta ) -> Result [int , str ]:
451+ return Option .of_optional (channel_indices_in_weights .get (ch .get_lookup_key ())).to_result (f"Channel not defined in weights: { ch } " )
452+
453+ match traverse_accumulate_errors (lambda t : index_channel_in_weights (snd (t )).map (lambda i : (fst (t ), i )))(enumerate (image_channels .values )):
454+ case result .Result (tag = "ok" , ok = index_pairs ):
455+ errors : list [str ] = []
456+ by_ch : list [Result [np .ndarray , list [str ]]] = []
457+ for ch_img_index , ch_wts_index in index_pairs :
440458 match sequence_accumulate_errors (
441459 (
442- image_dimensions .get_channel_data (channel = ch_img , array = image ),
443- weights_dimensions .get_channel_data (channel = ch_wts , array = weights ),
460+ image_dimensions .get_channel_data (channel = ch_img_index , array = image ),
461+ weights_dimensions .get_channel_data (channel = ch_wts_index , array = weights ),
444462 )
445463 ):
446464 case result .Result (tag = "ok" , ok = (img , wts )):
@@ -451,8 +469,10 @@ def compute_corrected_channels(
451469 )
452470 case result .Result (tag = "error" , error = messages ):
453471 errors .append (
454- f"For image channel { ch_img } and weights channel { ch_wts } , { len (errors )} error(s) extracting data: { ', ' .join (messages )} "
472+ f"For image channel { ch_img_index } and weights channel { ch_wts_index } , { len (errors )} error(s) extracting data: { ', ' .join (messages )} "
455473 )
456- case result .Result (tag = "error" , error = err ):
457- return Result .Error ([err ])
458- return Result .Error (errors ) if errors else Result .Ok (by_ch )
474+ return Result .Error (errors ) if errors else Result .Ok (by_ch )
475+ case result .Result (tag = "error" , error = messages ):
476+ return Result .Error (messages )
477+ case unknown :
478+ raise TypeError (f"Expected an expression.Result-wrapped value but got a { type (unknown ).__name__ } " )
0 commit comments