Commit 71abc21

Docs + change norm in inference

Normalize each window instead of the whole image to avoid memory issues.

1 parent: efe213d

File tree: 2 files changed (+8, -13 lines)


docs/res/guides/training_module_guide.rst

Lines changed: 2 additions & 8 deletions
@@ -95,22 +95,16 @@ Function Reference
 ======================== ================================================================================================
 Dice loss                `Dice Loss from MONAI`_ with ``sigmoid=true``
 Generalized Dice loss    `Generalized dice Loss from MONAI`_ with ``sigmoid=true``
-Dice-CE loss             `Dice-CE Loss from MONAI`_ with ``sigmoid=true``
+Dice-CE loss             `Dice-CrossEntropy Loss from MONAI`_ with ``sigmoid=true``
 Tversky loss             `Tversky Loss from MONAI`_ with ``sigmoid=true``
 ======================== ================================================================================================
 
-
-.. Binary cross-entropy  `Binary cross entropy (BCE) loss from PyTorch`_
-   BCE with logits       `BCE loss with logits from PyTorch`_
-
 .. _Dice Loss from MONAI: https://docs.monai.io/en/stable/losses.html#diceloss
 .. _Focal Loss from MONAI: https://docs.monai.io/en/stable/losses.html#focalloss
 .. _Dice-focal Loss from MONAI: https://docs.monai.io/en/stable/losses.html#dicefocalloss
 .. _Generalized dice Loss from MONAI: https://docs.monai.io/en/stable/losses.html#generalizeddiceloss
-.. _Dice-CE Loss from MONAI: https://docs.monai.io/en/stable/losses.html#diceceloss
+.. _Dice-CrossEntropy Loss from MONAI: https://docs.monai.io/en/stable/losses.html#diceceloss
 .. _Tversky Loss from MONAI: https://docs.monai.io/en/stable/losses.html#tverskyloss
-.. _Binary cross entropy (BCE) loss from PyTorch:
-.. _BCE loss with logits from PyTorch:
 
 Once you are ready, press the Start button to begin training. The module will automatically load your dataset,
 perform data augmentation if you chose to, select a CUDA device if one is present, and train the model.
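
As a side note on the table above: every loss row is wired up with ``sigmoid=true``, meaning the model's raw logits are passed through a sigmoid inside the loss. Below is a minimal sketch of what that looks like against MONAI, using the Dice loss row; the table applies the same setting to the other rows, including the renamed Dice-CE entry. The tensor shapes and random data are illustrative assumptions, not plugin defaults.

# Sketch only: what "with sigmoid=true" means for the losses in this table.
# Shapes and random data are illustrative assumptions, not plugin defaults.
import torch
from monai.losses import DiceLoss

loss_fn = DiceLoss(sigmoid=True)  # apply a sigmoid to raw logits inside the loss
logits = torch.randn(2, 1, 16, 16, 16)                    # model output (B, C, D, H, W)
labels = torch.randint(0, 2, (2, 1, 16, 16, 16)).float()  # binary ground truth
loss = loss_fn(logits, labels)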

napari_cellseg3d/code_models/worker_inference.py

Lines changed: 6 additions & 5 deletions
@@ -33,7 +33,6 @@
     LogSignal,
     ONNXModelWrapper,
     QuantileNormalization,
-    QuantileNormalizationd,
     RemapTensor,
     Threshold,
     WeightsDownloader,
@@ -194,7 +193,7 @@ def load_folder(self):
                 EnsureChannelFirstd(keys=["image"]),
                 # Orientationd(keys=["image"], axcodes="PLI"),
                 # anisotropic_transform,
-                QuantileNormalizationd(keys=["image"]),
+                # QuantileNormalizationd(keys=["image"]),
                 EnsureTyped(keys=["image"]),
             ]
         )
@@ -204,7 +203,7 @@ def load_folder(self):
                 LoadImaged(keys=["image"]),
                 # AddChanneld(keys=["image"]), #already done
                 EnsureChannelFirstd(keys=["image"]),
-                QuantileNormalizationd(keys=["image"]),
+                # QuantileNormalizationd(keys=["image"]),
                 # Orientationd(keys=["image"], axcodes="PLI"),
                 # anisotropic_transform,
                 SpatialPadd(keys=["image"], spatial_size=pad),
@@ -248,7 +247,7 @@ def load_layer(self):
                 # anisotropic_transform,
                 AddChannel(),
                 # SpatialPad(spatial_size=pad),
-                QuantileNormalization(),
+                # QuantileNormalization(),
                 AddChannel(),
                 EnsureType(),
             ],
@@ -263,7 +262,7 @@ def load_layer(self):
                 ToTensor(),
                 # anisotropic_transform,
                 AddChannel(),
-                QuantileNormalization(),
+                # QuantileNormalization(),
                 SpatialPad(spatial_size=pad),
                 AddChannel(),
                 EnsureType(),
@@ -301,10 +300,12 @@ def model_output(
         # logger.debug(f"model : {model}")
         logger.debug(f"inputs shape : {inputs.shape}")
         logger.debug(f"inputs type : {inputs.dtype}")
+        normalization = QuantileNormalization()
         try:
             # outputs = model(inputs)
 
             def model_output_wrapper(inputs):
+                inputs = normalization(inputs)
                 result = model(inputs)
                 return post_process_transforms(result)
 
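Taken together, the changes in this file drop QuantileNormalization from the loading transforms and instead apply it to each window right before the model is called. The following is a minimal sketch of that idea, assuming a sliding-window inference loop; the 1%/99% quantiles, the [0, 1] rescale, and the helper names are illustrative assumptions, not the plugin's actual QuantileNormalization.

# Sketch only: per-window normalization inside the predictor used by a
# sliding-window inference loop. The 1%/99% quantiles, the [0, 1] rescale,
# and the helper names are assumptions, not the plugin's QuantileNormalization.
import torch


def quantile_normalize(window, low=0.01, high=0.99):
    """Clip one window to its own quantile range, then rescale to [0, 1]."""
    window = window.float()
    q_low = torch.quantile(window, low)
    q_high = torch.quantile(window, high)
    window = torch.clamp(window, q_low, q_high)
    return (window - q_low) / (q_high - q_low + 1e-8)


def make_window_predictor(model, post_process):
    """Wrap the model so each window is normalized on its own, instead of
    normalizing the full volume up front and holding it all in memory."""

    def predictor(inputs):
        inputs = quantile_normalize(inputs)  # touches only this window
        return post_process(model(inputs))

    return predictor

If such a wrapper is passed as the predictor argument of MONAI's sliding_window_inference, normalization statistics are computed per ROI-sized window rather than once over the full volume, which is the memory saving the commit message refers to.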