File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed
Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -2498,7 +2498,7 @@ def distance_transform_edt(
24982498 if return_indices :
24992499 dtype = torch .int32
25002500 if indices is None :
2501- indices = torch .zeros ((img .dim (),) + img .shape , dtype = dtype ) # type: ignore
2501+ indices = torch .zeros ((img .shape [ 0 ],) + ( img . dim () - 1 ,) + img .shape [ 1 :] , dtype = dtype ) # type: ignore
25022502 else :
25032503 if not isinstance (indices , torch .Tensor ) and indices .device != img .device :
25042504 raise TypeError ("indices must be a torch.Tensor on the same device as img" )
@@ -2532,7 +2532,7 @@ def distance_transform_edt(
25322532 raise TypeError ("distances must be a numpy.ndarray of dtype float64" )
25332533 if return_indices :
25342534 if indices is None :
2535- indices = np .zeros ((img_ .ndim ,) + img_ .shape , dtype = np .int32 )
2535+ indices = np .zeros ((img_ .shape [ 0 ],) + ( img_ . ndim - 1 ,) + img_ .shape [ 1 :] , dtype = np .int32 )
25362536 else :
25372537 if not isinstance (indices , np .ndarray ):
25382538 raise TypeError ("indices must be a numpy.ndarray" )
You can’t perform that action at this time.
0 commit comments