Skip to content

Commit 02d528c

Browse files
林旻佑林旻佑
authored andcommitted
Fix: correct import, English docstring, safer channel heuristic; robust y channel-last detection (refs #8366)
Signed-off-by: 林旻佑 <[email protected]>
1 parent c2b5fc4 commit 02d528c

File tree

1 file changed

+32
-22
lines changed

1 file changed

+32
-22
lines changed

monai/inferers/utils.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -40,37 +40,47 @@
4040

4141
def ensure_channel_first(x: torch.Tensor, spatial_ndim: Optional[int] = None) -> Tuple[torch.Tensor, int]:
4242
"""
43-
將張量標準化為 channel-first(N,C,spatial...)。
44-
回傳 (可能已轉換的張量, 原本 channel 維度:1 表示本來就在 dim=1;-1 表示本來在最後一維)。
43+
Normalize a tensor to channel-first layout (N, C, spatial...).
4544
46-
支援常見情況:
47-
- [N, C, *spatial] -> 原樣返回
48-
- [N, *spatial, C] -> 移動最後一維到 dim=1
49-
其他模糊情況則丟出 ValueError,避免悄悄算錯。
50-
"""
51-
if not isinstance(x, torch.Tensor):
52-
raise TypeError(f"expect torch.Tensor, got {type(x)}")
53-
if x.ndim < 3:
54-
raise ValueError(f"expect >=3 dims (N,C,spatial...), got shape={tuple(x.shape)}")
45+
Args:
46+
x: Tensor with shape (N, C, spatial...) or (N, spatial..., C).
47+
spatial_ndim: Number of spatial dimensions. If None, inferred as x.ndim - 2.
48+
49+
Returns:
50+
A tuple (x_cf, orig_channel_dim):
51+
- x_cf: the tensor in channel-first layout.
52+
- orig_channel_dim: 1 if input was already channel-first; -1 if the channel was last.
53+
54+
Raises:
55+
TypeError: if x is not a torch.Tensor.
56+
ValueError: if x.ndim < 3 or the channel dimension cannot be inferred unambiguously.
5557
56-
# 若未指定,估個常見的 2D/3D 空間維度數,僅用於錯誤訊息與判斷參考
58+
Notes:
59+
Uses a small-channel heuristic (<=32) typical for segmentation/classification. When ambiguous,
60+
prefers preserving the input layout or raises ValueError to avoid silent errors.
61+
"""
62+
63+
5764
if spatial_ndim is None:
58-
spatial_ndim = max(2, min(3, x.ndim - 2))
65+
spatial_ndim = x.ndim - 2
5966

60-
# 簡單啟發式:C 通常不會太大(<=512)
61-
c_first_ok = x.shape[1] <= 512
62-
c_last_ok = x.shape[-1] <= 512
67+
threshold = 32
68+
s1, sl = int(x.shape[1]), int(x.shape[-1])
6369

64-
# 優先保留 channel-first
65-
if c_first_ok and x.ndim >= 2 + spatial_ndim:
70+
if s1 <= threshold and sl > threshold:
6671
return x, 1
67-
if c_last_ok:
72+
if sl <= threshold and s1 > threshold:
6873
return x.movedim(-1, 1), -1
6974

75+
if s1 <= threshold and sl <= threshold:
76+
return x, 1
77+
7078
raise ValueError(
71-
f"cannot infer channel dim for shape={tuple(x.shape)}; "
72-
f"expected [N,C,spatial...] or [N,spatial...,C] (spatial_ndim≈{spatial_ndim})"
73-
)
79+
f"cannot infer channel dim for shape={tuple(x.shape)}; expected [N,C,spatial...] or [N,spatial...,C]; "
80+
f"both dim1={s1} and dim-1={sl} look like spatial dims"
81+
)
82+
83+
7484
def sliding_window_inference(
7585
inputs: torch.Tensor | MetaTensor,
7686
roi_size: Sequence[int] | int,

0 commit comments

Comments
 (0)