Skip to content

Commit ca15762

Browse files
authored
Enhance rescale_array (#3246)
* [DLMED] enhance rescale Signed-off-by: Nic Ma <[email protected]> * [DLMED] update according to comments Signed-off-by: Nic Ma <[email protected]> * [DLMED] fix typo Signed-off-by: Nic Ma <[email protected]>
1 parent 547830b commit ca15762

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

monai/transforms/intensity/array.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -436,14 +436,14 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
436436
if self.minv is not None and self.maxv is not None:
437437
if self.channel_wise:
438438
out = [rescale_array(d, self.minv, self.maxv, dtype=self.dtype) for d in img]
439-
return torch.stack(out) if isinstance(img, torch.Tensor) else np.stack(out) # type: ignore
440-
return rescale_array(img, self.minv, self.maxv, dtype=self.dtype)
441-
if self.factor is not None:
442-
ret = img * (1 + self.factor)
443-
if self.dtype is not None:
444-
ret, *_ = convert_data_type(ret, dtype=self.dtype or img.dtype)
445-
return ret
446-
raise ValueError("Incompatible values: minv=None or maxv=None and factor=None.")
439+
ret = torch.stack(out) if isinstance(img, torch.Tensor) else np.stack(out) # type: ignore
440+
else:
441+
ret = rescale_array(img, self.minv, self.maxv, dtype=self.dtype)
442+
else:
443+
ret = (img * (1 + self.factor)) if self.factor is not None else img
444+
445+
ret, *_ = convert_data_type(ret, dtype=self.dtype or img.dtype)
446+
return ret
447447

448448

449449
class RandScaleIntensity(RandomizableTransform):

monai/transforms/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,13 @@ def rescale_array(
156156
) -> NdarrayOrTensor:
157157
"""
158158
Rescale the values of numpy array `arr` to be from `minv` to `maxv`.
159+
160+
Args:
161+
arr: input array to rescale.
162+
minv: minimum value of target rescaled array.
163+
maxv: maxmum value of target rescaled array.
164+
dtype: if not None, convert input array to dtype before computation.
165+
159166
"""
160167
if dtype is not None:
161168
arr, *_ = convert_data_type(arr, dtype=dtype)

0 commit comments

Comments
 (0)