77
88def get_3d_model (
99 out_channels : int ,
10+ in_channels : int = 1 ,
1011 scale_factors : Tuple [Tuple [int , int , int ]] = [[1 , 2 , 2 ], [2 , 2 , 2 ], [2 , 2 , 2 ], [2 , 2 , 2 ]],
1112 initial_features : int = 32 ,
1213 final_activation : str = "Sigmoid" ,
@@ -25,7 +26,7 @@ def get_3d_model(
2526 """
2627 model = AnisotropicUNet (
2728 scale_factors = scale_factors ,
28- in_channels = 1 ,
29+ in_channels = in_channels ,
2930 out_channels = out_channels ,
3031 initial_features = initial_features ,
3132 gain = 2 ,
@@ -36,6 +37,7 @@ def get_3d_model(
3637
3738def get_2d_model (
3839 out_channels : int ,
40+ in_channels : int = 1 ,
3941 initial_features : int = 32 ,
4042 final_activation : str = "Sigmoid" ,
4143) -> torch .nn .Module :
@@ -51,7 +53,7 @@ def get_2d_model(
5153 The U-Net.
5254 """
5355 model = UNet2d (
54- in_channels = 1 ,
56+ in_channels = in_channels ,
5557 out_channels = out_channels ,
5658 initial_features = initial_features ,
5759 gain = 2 ,
@@ -183,6 +185,7 @@ def supervised_training(
183185 check : bool = False ,
184186 ignore_label : Optional [int ] = None ,
185187 label_transform : Optional [callable ] = None ,
188+ in_channels : int = 1 ,
186189 out_channels : int = 2 ,
187190 mask_channel : bool = False ,
188191 ** loader_kwargs ,
@@ -242,9 +245,9 @@ def supervised_training(
242245
243246 is_2d , _ = _determine_ndim (patch_shape )
244247 if is_2d :
245- model = get_2d_model (out_channels = out_channels )
248+ model = get_2d_model (out_channels = out_channels , in_channels = in_channels )
246249 else :
247- model = get_3d_model (out_channels = out_channels )
250+ model = get_3d_model (out_channels = out_channels , in_channels = in_channels )
248251
249252 loss , metric = None , None
250253 # No ignore label -> we can use default loss.
0 commit comments