"""predict.py: load a saved model and write a binary prediction mask for
every image found in ``data/test_set_images``.

Usage:
    python predict.py [--model_name NAME]
"""
import os
import cv2
import torch
import numpy as np
from training_loop import get_model
import argparse
############################
# Parameters
############################
# Command-line interface. The description is what `--help` prints, so it must
# describe prediction (this script), not training.
parser = argparse.ArgumentParser(
    description="Run inference with a trained model and save the predicted masks."
)
parser.add_argument(
    "--model_name",
    type=str,
    default="ResUNet",
    help="Name of the model to use",
)
args = parser.parse_args()
model_name = args.model_name

# Every path is resolved against the current working directory, so the script
# has to be launched from the directory that contains `saved_models/` and
# `data/`.
PATH = os.getcwd()
model_path = os.path.join(PATH, "saved_models", f"{model_name}.pth")  # saved .pth weights
input_dir = os.path.join(PATH, "data", "test_set_images")  # images to predict on
output_dir = os.path.join(PATH, "outputs", "predictions")  # where masks are written
os.makedirs(output_dir, exist_ok=True)  # no error if the directory already exists

# Passed to cv2.resize as `dsize`, i.e. (width, height), before the image is
# handed to the model.
target_size = (416, 416)
############################
# Model Loading
############################
# Instantiate the architecture selected on the command line; it has to be the
# same architecture the checkpoint was saved from, otherwise the state dict
# will not match.
model = get_model(model_name)

# Prefer the GPU when one is available, fall back to the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): torch.load deserializes the file, so only checkpoints from a
# trusted source should be loaded here.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)

model.to(device)
model.eval()  # put the module in evaluation mode for inference
############################
# Preprocessing Function
############################
def preprocess_image(image_path, target_size):
    """Read an image from disk and turn it into a single-item model batch.

    Args:
        image_path: Path of the image file, read with ``cv2.imread``.
        target_size: Size handed to ``cv2.resize`` as ``dsize``, i.e.
            ``(width, height)``.

    Returns:
        A float32 tensor in channels-first layout with a leading batch axis
        of size 1, values scaled to ``[0, 1]``, placed on the module-level
        ``device``. The channel order is whatever OpenCV loaded (no colour
        conversion is applied here).

    Raises:
        ValueError: If OpenCV could not read ``image_path``.
    """
    raw = cv2.imread(image_path)
    if raw is None:
        raise ValueError(f"Failed to load image: {image_path}")

    resized = cv2.resize(raw, target_size)
    scaled = resized.astype(np.float32) / 255.0  # pixel values -> [0, 1]
    channels_first = scaled.transpose(2, 0, 1)  # HWC -> CHW
    batch = torch.tensor(channels_first, dtype=torch.float32).unsqueeze(0)  # add batch axis
    return batch.to(device)
############################
# Postprocessing Function
############################
def postprocess_prediction(prediction, original_shape):
    """Convert the raw model output into a binary mask at the source image size.

    Args:
        prediction: Tensor returned by the model. Presumably single-channel
            logits for a batch of one (all size-1 axes are squeezed away) --
            TODO confirm against the model definition.
        original_shape: ``(height, width)`` of the source image.

    Returns:
        A ``uint8`` NumPy array holding 0/1 values, resized to the source
        image's dimensions.
    """
    probabilities = torch.sigmoid(prediction)
    probabilities = probabilities.squeeze().cpu().numpy()  # drop size-1 axes, move to host

    mask = (probabilities > 0.5).astype(np.uint8)  # hard threshold -> {0, 1}

    # cv2.resize wants (width, height), the reverse of NumPy's (height, width).
    width_height = original_shape[::-1]
    # Nearest-neighbour keeps the mask strictly binary after resizing.
    return cv2.resize(mask, width_height, interpolation=cv2.INTER_NEAREST)
############################
# Prediction Loop
############################
# Collect every image in the input directory. The names are sorted because
# os.listdir returns them in arbitrary order, which would make the processing
# order (and the log) differ from run to run.
image_files = sorted(
    f for f in os.listdir(input_dir)
    if f.lower().endswith(('.png', '.jpg', '.jpeg'))
)

for image_file in image_files:
    image_path = os.path.join(input_dir, image_file)
    # NOTE: the output keeps the input's extension, so masks for .jpg/.jpeg
    # inputs are JPEG-compressed and may not be strictly 0/255 when read back.
    output_path = os.path.join(output_dir, f"prediction_{image_file}")

    # Read the image once up front to learn its original (height, width).
    # cv2.imread signals failure by returning None rather than raising, so
    # check before touching `.shape` and skip files that cannot be decoded
    # instead of aborting the whole run.
    original_image = cv2.imread(image_path)
    if original_image is None:
        print(f"Skipping unreadable image: {image_path}")
        continue
    original_shape = original_image.shape[:2]

    image_tensor = preprocess_image(image_path, target_size)

    # Gradients are not needed for inference.
    with torch.no_grad():
        prediction = model(image_tensor)

    # Binary mask resized back to the original image dimensions.
    prediction_mask = postprocess_prediction(prediction, original_shape)

    # The mask holds 0/1; scale to 0/255 so it is visible as an image.
    # cv2.imwrite reports failure through its return value, so only claim
    # success when the file was actually written.
    if not cv2.imwrite(output_path, prediction_mask * 255):
        print(f"Failed to save prediction to {output_path}")
        continue
    print(f"Prediction saved to {output_path}")