22
33from collections .abc import Callable
44from pathlib import Path
5- from typing import Any , Literal , Optional , Union , overload
5+ from typing import Any , Literal , Union , overload
66
77import numpy as np
88from numpy .typing import NDArray
@@ -79,25 +79,25 @@ class CAREamist:
7979 def __init__ ( # numpydoc ignore=GL08
8080 self ,
8181 source : Union [Path , str ],
82- work_dir : Optional [ Union [Path , str ]] = None ,
83- callbacks : Optional [ list [Callback ]] = None ,
82+ work_dir : Union [Path , str ] | None = None ,
83+ callbacks : list [Callback ] | None = None ,
8484 enable_progress_bar : bool = True ,
8585 ) -> None : ...
8686
8787 @overload
8888 def __init__ ( # numpydoc ignore=GL08
8989 self ,
9090 source : Configuration ,
91- work_dir : Optional [ Union [Path , str ]] = None ,
92- callbacks : Optional [ list [Callback ]] = None ,
91+ work_dir : Union [Path , str ] | None = None ,
92+ callbacks : list [Callback ] | None = None ,
9393 enable_progress_bar : bool = True ,
9494 ) -> None : ...
9595
9696 def __init__ (
9797 self ,
9898 source : Union [Path , str , Configuration ],
99- work_dir : Optional [ Union [Path , str ]] = None ,
100- callbacks : Optional [ list [Callback ]] = None ,
99+ work_dir : Union [Path , str ] | None = None ,
100+ callbacks : list [Callback ] | None = None ,
101101 enable_progress_bar : bool = True ,
102102 ) -> None :
103103 """
@@ -222,11 +222,11 @@ def __init__(
222222 )
223223
224224 # place holder for the datamodules
225- self .train_datamodule : Optional [ TrainDataModule ] = None
226- self .pred_datamodule : Optional [ PredictDataModule ] = None
225+ self .train_datamodule : TrainDataModule | None = None
226+ self .pred_datamodule : PredictDataModule | None = None
227227
228228 def _define_callbacks (
229- self , callbacks : Optional [ list [Callback ]] , enable_progress_bar : bool
229+ self , callbacks : list [Callback ] | None , enable_progress_bar : bool
230230 ) -> None :
231231 """Define the callbacks for the training loop.
232232
@@ -288,11 +288,11 @@ def stop_training(self) -> None:
288288 def train (
289289 self ,
290290 * ,
291- datamodule : Optional [ TrainDataModule ] = None ,
292- train_source : Optional [ Union [Path , str , NDArray ]] = None ,
293- val_source : Optional [ Union [Path , str , NDArray ]] = None ,
294- train_target : Optional [ Union [Path , str , NDArray ]] = None ,
295- val_target : Optional [ Union [Path , str , NDArray ]] = None ,
291+ datamodule : TrainDataModule | None = None ,
292+ train_source : Union [Path , str , NDArray ] | None = None ,
293+ val_source : Union [Path , str , NDArray ] | None = None ,
294+ train_target : Union [Path , str , NDArray ] | None = None ,
295+ val_target : Union [Path , str , NDArray ] | None = None ,
296296 use_in_memory : bool = True ,
297297 val_percentage : float = 0.1 ,
298298 val_minimum_split : int = 1 ,
@@ -443,9 +443,9 @@ def _train_on_datamodule(self, datamodule: TrainDataModule) -> None:
443443 def _train_on_array (
444444 self ,
445445 train_data : NDArray ,
446- val_data : Optional [ NDArray ] = None ,
447- train_target : Optional [ NDArray ] = None ,
448- val_target : Optional [ NDArray ] = None ,
446+ val_data : NDArray | None = None ,
447+ train_target : NDArray | None = None ,
448+ val_target : NDArray | None = None ,
449449 val_percentage : float = 0.1 ,
450450 val_minimum_split : int = 5 ,
451451 ) -> None :
@@ -484,9 +484,9 @@ def _train_on_array(
484484 def _train_on_path (
485485 self ,
486486 path_to_train_data : Union [Path , str ],
487- path_to_val_data : Optional [ Union [Path , str ]] = None ,
488- path_to_train_target : Optional [ Union [Path , str ]] = None ,
489- path_to_val_target : Optional [ Union [Path , str ]] = None ,
487+ path_to_val_data : Union [Path , str ] | None = None ,
488+ path_to_train_target : Union [Path , str ] | None = None ,
489+ path_to_val_target : Union [Path , str ] | None = None ,
490490 use_in_memory : bool = True ,
491491 val_percentage : float = 0.1 ,
492492 val_minimum_split : int = 1 ,
@@ -549,13 +549,13 @@ def predict( # numpydoc ignore=GL08
549549 source : Union [Path , str ],
550550 * ,
551551 batch_size : int = 1 ,
552- tile_size : Optional [ tuple [int , ...]] = None ,
553- tile_overlap : Optional [ tuple [int , ...]] = (48 , 48 ),
554- axes : Optional [ str ] = None ,
555- data_type : Optional [ Literal ["tiff" , "custom" ]] = None ,
552+ tile_size : tuple [int , ...] | None = None ,
553+ tile_overlap : tuple [int , ...] | None = (48 , 48 ),
554+ axes : str | None = None ,
555+ data_type : Literal ["tiff" , "custom" ] | None = None ,
556556 tta_transforms : bool = False ,
557- dataloader_params : Optional [ dict ] = None ,
558- read_source_func : Optional [ Callable ] = None ,
557+ dataloader_params : dict | None = None ,
558+ read_source_func : Callable | None = None ,
559559 extension_filter : str = "" ,
560560 ) -> Union [list [NDArray ], NDArray ]: ...
561561
@@ -565,26 +565,26 @@ def predict( # numpydoc ignore=GL08
565565 source : NDArray ,
566566 * ,
567567 batch_size : int = 1 ,
568- tile_size : Optional [ tuple [int , ...]] = None ,
569- tile_overlap : Optional [ tuple [int , ...]] = (48 , 48 ),
570- axes : Optional [ str ] = None ,
571- data_type : Optional [ Literal ["array" ]] = None ,
568+ tile_size : tuple [int , ...] | None = None ,
569+ tile_overlap : tuple [int , ...] | None = (48 , 48 ),
570+ axes : str | None = None ,
571+ data_type : Literal ["array" ] | None = None ,
572572 tta_transforms : bool = False ,
573- dataloader_params : Optional [ dict ] = None ,
573+ dataloader_params : dict | None = None ,
574574 ) -> Union [list [NDArray ], NDArray ]: ...
575575
576576 def predict (
577577 self ,
578578 source : Union [PredictDataModule , Path , str , NDArray ],
579579 * ,
580580 batch_size : int = 1 ,
581- tile_size : Optional [ tuple [int , ...]] = None ,
582- tile_overlap : Optional [ tuple [int , ...]] = (48 , 48 ),
583- axes : Optional [ str ] = None ,
584- data_type : Optional [ Literal ["array" , "tiff" , "custom" ]] = None ,
581+ tile_size : tuple [int , ...] | None = None ,
582+ tile_overlap : tuple [int , ...] | None = (48 , 48 ),
583+ axes : str | None = None ,
584+ data_type : Literal ["array" , "tiff" , "custom" ] | None = None ,
585585 tta_transforms : bool = False ,
586- dataloader_params : Optional [ dict ] = None ,
587- read_source_func : Optional [ Callable ] = None ,
586+ dataloader_params : dict | None = None ,
587+ read_source_func : Callable | None = None ,
588588 extension_filter : str = "" ,
589589 ** kwargs : Any ,
590590 ) -> Union [list [NDArray ], NDArray ]:
@@ -704,18 +704,18 @@ def predict_to_disk(
704704 source : Union [PredictDataModule , Path , str ],
705705 * ,
706706 batch_size : int = 1 ,
707- tile_size : Optional [ tuple [int , ...]] = None ,
708- tile_overlap : Optional [ tuple [int , ...]] = (48 , 48 ),
709- axes : Optional [ str ] = None ,
710- data_type : Optional [ Literal ["tiff" , "custom" ]] = None ,
707+ tile_size : tuple [int , ...] | None = None ,
708+ tile_overlap : tuple [int , ...] | None = (48 , 48 ),
709+ axes : str | None = None ,
710+ data_type : Literal ["tiff" , "custom" ] | None = None ,
711711 tta_transforms : bool = False ,
712- dataloader_params : Optional [ dict ] = None ,
713- read_source_func : Optional [ Callable ] = None ,
712+ dataloader_params : dict | None = None ,
713+ read_source_func : Callable | None = None ,
714714 extension_filter : str = "" ,
715715 write_type : Literal ["tiff" , "custom" ] = "tiff" ,
716- write_extension : Optional [ str ] = None ,
717- write_func : Optional [ WriteFunc ] = None ,
718- write_func_kwargs : Optional [ dict [str , Any ]] = None ,
716+ write_extension : str | None = None ,
717+ write_func : WriteFunc | None = None ,
718+ write_func_kwargs : dict [str , Any ] | None = None ,
719719 prediction_dir : Union [Path , str ] = "predictions" ,
720720 ** kwargs ,
721721 ) -> None :
@@ -885,8 +885,8 @@ def export_to_bmz(
885885 authors : list [dict ],
886886 general_description : str ,
887887 data_description : str ,
888- covers : Optional [ list [Union [Path , str ]]] = None ,
889- channel_names : Optional [ list [str ]] = None ,
888+ covers : list [Union [Path , str ]] | None = None ,
889+ channel_names : list [str ] | None = None ,
890890 model_version : str = "0.1.0" ,
891891 ) -> None :
892892 """Export the model to the BioImage Model Zoo format.
0 commit comments