-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
224 lines (175 loc) · 8.59 KB
/
predict.py
File metadata and controls
224 lines (175 loc) · 8.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import argparse
import glob
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from scipy import ndimage
from torchvision import transforms
# SwinUNETR backbone (must match the architecture used in train.py)
from monai.networks.nets import SwinUNETR
from skimage import morphology

from utils.data_loading import BasicDataset
from utils.utils import plot_img_and_mask
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.3,
                n_classes=1):
    """Run `net` on a single PIL image and return a per-pixel label mask.

    Args:
        net: trained segmentation network (expects 512x512 inputs).
        full_img: input PIL image at its native resolution.
        device: torch device to run inference on.
        scale_factor: unused here; kept for call-site compatibility (the
            image is always forced to 512x512 regardless).
        out_threshold: sigmoid threshold for the binary (n_classes == 1) case.
        n_classes: number of output classes; >1 switches to argmax decoding.

    Returns:
        numpy int array of shape (H, W) at the ORIGINAL image resolution.
    """
    net.eval()

    # Swin-style transformers are sensitive to input resolution: feed the
    # network exactly the 512x512 size it was trained on, and remember the
    # native size so the prediction can be restored afterwards.
    original_w, original_h = full_img.size
    resized = full_img.resize((512, 512), resample=Image.BICUBIC)

    # HWC uint8 -> CHW float in [0, 1] (mirrors BasicDataset preprocessing
    # minus its scale handling, which we bypass by resizing above).
    arr = np.asarray(resized)
    if arr.ndim == 2:
        arr = arr[np.newaxis, ...]
    else:
        arr = arr.transpose((2, 0, 1))
    arr = arr / 255.0
    batch = torch.from_numpy(arr).unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(batch)

    # Bilinearly interpolate the (1, C, 512, 512) logits back to the native
    # resolution; F.interpolate takes size as (H, W).
    if output.shape[-2:] != (original_h, original_w):
        output = F.interpolate(output, size=(original_h, original_w),
                               mode='bilinear', align_corners=False)

    # Decode logits to a label mask.
    mask = output.argmax(dim=1) if n_classes > 1 else torch.sigmoid(output) > out_threshold

    return mask[0].long().squeeze().cpu().numpy()
def get_args():
    """Build and parse the command-line interface for mask prediction."""
    p = argparse.ArgumentParser(description='Predict masks from input images')
    p.add_argument('--model', '-m', metavar='FILE', default='MODEL.pth',
                   help='Specify the file in which the model is stored')
    p.add_argument('--input', '-i', metavar='INPUT', nargs='+', required=True,
                   help='Filenames of input images')
    p.add_argument('--output', '-o', metavar='OUTPUT', nargs='+',
                   help='Filenames of output images')
    p.add_argument('--viz', '-v', action='store_true',
                   help='Visualize the images as they are processed')
    p.add_argument('--no-save', '-n', action='store_true',
                   help='Do not save the output masks')
    # 0.3 rather than the usual 0.5: recall-leaning default for thin vessels.
    p.add_argument('--mask-threshold', '-t', type=float, default=0.3,
                   help='Minimum probability value to consider a mask pixel white')
    # The Swin pipeline resizes to 512x512 itself, so scale stays at 1.0.
    p.add_argument('--scale', '-s', type=float, default=1.0,
                   help='Scale factor for the input images')
    p.add_argument('--bilinear', action='store_true', default=False,
                   help='Use bilinear upsampling')
    p.add_argument('--classes', '-c', type=int, default=1,
                   help='Number of classes')
    p.add_argument('--fov-dir', type=str, default=None,
                   help='Directory containing FOV masks to remove boundary artifacts')
    return p.parse_args()
def get_output_filenames(args):
    """Return output mask paths: args.output if given, else `<input>_OUT.png`."""
    if args.output:
        return args.output
    return [f'{os.path.splitext(name)[0]}_OUT.png' for name in args.input]
def mask_to_image(mask: np.ndarray, mask_values):
    """Convert a predicted label mask to a grayscale PIL image.

    Args:
        mask: 2-D label array, or a 3-D class-score array (argmax'd over
            axis 0 here).
        mask_values: retained for interface compatibility only. The previous
            implementation built a value-mapped array `out` from it and then
            discarded it, so it never affected the returned image; that dead
            code has been removed.

    Returns:
        PIL image: a pure 0/1 mask is stretched to 0/255, anything else is
        emitted as raw uint8 values (e.g. the 0/255 output of post_process).
    """
    if mask.ndim == 3:
        mask = np.argmax(mask, axis=0)
    # Stretch a binary 0/1 mask to 0/255 so it is visible when saved.
    if mask.max() <= 1:
        mask = mask * 255
    return Image.fromarray(mask.astype(np.uint8))
def post_process(mask_pred, min_size=64):
    """Remove isolated connected components smaller than `min_size` pixels.

    Equivalent to skimage.morphology.remove_small_objects with
    connectivity=1 (4-connectivity in 2-D), implemented directly on
    scipy.ndimage so it carries no skimage dependency.

    Args:
        mask_pred: 2-D numpy array; any value > 0 counts as foreground.
        min_size: minimum component area (in pixels) to keep.

    Returns:
        uint8 array with kept foreground as 255 and everything else as 0.
    """
    binary = mask_pred > 0
    # 4-connectivity labelling (scipy's default structure matches
    # skimage's connectivity=1).
    labels, num = ndimage.label(binary)
    if num == 0:
        return np.zeros(binary.shape, dtype=np.uint8)
    # Label 0 is background; keep only components with area >= min_size.
    areas = np.bincount(labels.ravel())
    keep = areas >= min_size
    keep[0] = False
    cleaned = keep[labels]
    return cleaned.astype(np.uint8) * 255
if __name__ == '__main__':
    args = get_args()
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    in_files = args.input
    out_files = get_output_filenames(args)

    # Instantiate SwinUNETR with hyperparameters identical to train.py;
    # any mismatch here makes load_state_dict() fail.
    net = SwinUNETR(
        img_size=(512, 512),
        in_channels=3,
        out_channels=args.classes,
        feature_size=96,
        depths=(2, 2, 6, 2),        # must match the training depth config
        num_heads=(3, 6, 12, 24),   # must match the training attention heads
        use_checkpoint=True,
        spatial_dims=2
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Loading model {args.model}')
    logging.info(f'Using device {device}')
    net.to(device=device)

    # train.py stores 'mask_values' alongside the weights; pop() both reads
    # and removes the key, so no further cleanup is needed before loading.
    state_dict = torch.load(args.model, map_location=device)
    mask_values = state_dict.pop('mask_values', [0, 1])
    net.load_state_dict(state_dict)
    logging.info('Model loaded!')

    for i, filename in enumerate(in_files):
        logging.info(f'Predicting image {filename} ...')
        img = Image.open(filename)

        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=args.scale,
                           out_threshold=args.mask_threshold,
                           device=device,
                           n_classes=args.classes)

        # Optionally apply a FOV (field-of-view) mask to zero out false
        # positives outside the circular retina region.
        if args.fov_dir:
            base_name = os.path.basename(filename)
            name_without_ext = os.path.splitext(base_name)[0]
            # DRIVE filenames start with a two-digit ID; use it to fuzzy-match
            # the corresponding FOV mask file.
            file_id = name_without_ext[:2]
            search_pattern = os.path.join(args.fov_dir, f"{file_id}*.*")
            fov_files = glob.glob(search_pattern)

            if fov_files:
                fov_mask_path = fov_files[0]
                fov_img = Image.open(fov_mask_path).convert('L')
                # Align the FOV mask size exactly with the prediction.
                if fov_img.size != (mask.shape[1], mask.shape[0]):
                    fov_img = fov_img.resize((mask.shape[1], mask.shape[0]), Image.NEAREST)
                fov_np = np.array(fov_img)
                # Binarize: DRIVE FOV masks are white (255) inside, black outside.
                fov_bool = fov_np > 128
                # Elementwise multiply keeps predictions only inside the FOV.
                mask = mask * fov_bool
                logging.info(f'Applied FOV mask: {os.path.basename(fov_mask_path)}')
            else:
                logging.warning(f'No FOV mask found for {filename} in {args.fov_dir}')

        # Remove isolated specks below 100 px (was accidentally run twice).
        result_cleaned = post_process(mask, min_size=100)

        if not args.no_save:
            out_filename = out_files[i]
            result = mask_to_image(result_cleaned, mask_values)
            result.save(out_filename)
            logging.info(f'Mask saved to {out_filename}')

        if args.viz:
            logging.info(f'Visualizing results for image {filename}, close to continue...')
            plot_img_and_mask(img, mask)