Skip to content

Commit 11a4f46

Browse files
authored
Merge branch 'dev' into fix-issue-8601
2 parents 1f37d0d + 01711cf commit 11a4f46

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

monai/transforms/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2498,7 +2498,7 @@ def distance_transform_edt(
24982498
if return_indices:
24992499
dtype = torch.int32
25002500
if indices is None:
2501-
indices = torch.zeros((img.dim(),) + img.shape, dtype=dtype) # type: ignore
2501+
indices = torch.zeros((img.shape[0],) + (img.dim() - 1,) + img.shape[1:], dtype=dtype) # type: ignore
25022502
else:
25032503
if not isinstance(indices, torch.Tensor) and indices.device != img.device:
25042504
raise TypeError("indices must be a torch.Tensor on the same device as img")
@@ -2532,7 +2532,7 @@ def distance_transform_edt(
25322532
raise TypeError("distances must be a numpy.ndarray of dtype float64")
25332533
if return_indices:
25342534
if indices is None:
2535-
indices = np.zeros((img_.ndim,) + img_.shape, dtype=np.int32)
2535+
indices = np.zeros((img_.shape[0],) + (img_.ndim - 1,) + img_.shape[1:], dtype=np.int32)
25362536
else:
25372537
if not isinstance(indices, np.ndarray):
25382538
raise TypeError("indices must be a numpy.ndarray")

0 commit comments

Comments
 (0)