-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
72 lines (60 loc) · 1.68 KB
/
predict.py
File metadata and controls
72 lines (60 loc) · 1.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import argparse
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from src.model import DeepfakeModel
# --------------------
# Config
# --------------------
# Use CUDA when PyTorch reports it as available; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Checkpoint read by load_model(); a relative path, so it is resolved against
# the current working directory.
MODEL_PATH = "weights/deepfake_v1.pth"

# Maps the index of the model's highest-scoring output to a label
# (index 0 -> "REAL", index 1 -> "FAKE").
CLASS_NAMES = ["REAL", "FAKE"]
# --------------------
# Image preprocessing
# --------------------
# Pipeline applied to every input image before inference:
#   1. resize to exactly 224x224 (a fixed (h, w) size, so aspect ratio is not kept),
#   2. convert the PIL image to a tensor,
#   3. normalise each RGB channel with the mean/std below.
# NOTE(review): the mean/std are the common ImageNet statistics; presumably the
# model was trained with the same preprocessing — confirm against training code.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# --------------------
# Load model
# --------------------
def load_model(model_path=None, device=None):
    """Build a DeepfakeModel, load saved weights into it, and return it in eval mode.

    Args:
        model_path: Path to a saved ``state_dict``. ``None`` (the default) uses
            the module-level ``MODEL_PATH``.
        device: Device to place the model and weights on. ``None`` (the default)
            uses the module-level ``DEVICE``.

    Returns:
        The ``DeepfakeModel`` moved to ``device``, with the checkpoint weights
        loaded and ``eval()`` applied.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise, e.g.
        ``FileNotFoundError`` for a missing checkpoint or ``RuntimeError`` when
        the checkpoint keys do not match the model.
    """
    # Resolve defaults at call time so later changes to the module constants
    # are honoured (default argument values would be frozen at import time).
    if model_path is None:
        model_path = MODEL_PATH
    if device is None:
        device = DEVICE

    model = DeepfakeModel().to(device)
    # map_location remaps stored tensors onto `device`, so a checkpoint saved on
    # GPU can still be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file and can execute arbitrary code;
    # only load trusted checkpoints (consider weights_only=True if the installed
    # PyTorch version supports it).
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    # Switch layers such as dropout / batch-norm to inference behaviour.
    model.eval()
    return model
# --------------------
# Predict
# --------------------
def predict(image_path, model=None):
    """Classify one image and return the predicted label with its probability.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (typically a file path).
        model: Optional model that is already loaded. When omitted,
            ``load_model()`` is called, which rebuilds the network and re-reads
            the checkpoint from disk on every call. Pass a model explicitly when
            classifying many images. A supplied model is expected to be in eval
            mode and on ``DEVICE``, because the input tensor is moved there.

    Returns:
        ``(label, confidence)``. ``label`` is the ``CLASS_NAMES`` entry for the
        highest-probability output. ``confidence`` is that softmax probability
        as a Python float.
    """
    # The context manager closes the underlying file once the pixels have been
    # decoded by convert().
    with Image.open(image_path) as img:
        # unsqueeze(0) adds the batch dimension expected by the model.
        tensor = transform(img.convert("RGB")).unsqueeze(0).to(DEVICE)

    if model is None:
        model = load_model()

    # No autograd bookkeeping is needed for inference.
    with torch.no_grad():
        logits = model(tensor)

    probs = F.softmax(logits, dim=1)
    conf, pred = torch.max(probs, dim=1)
    # .item() works because the batch holds exactly one image.
    label = CLASS_NAMES[pred.item()]
    confidence = conf.item()
    return label, confidence
# --------------------
# CLI
# --------------------
def main(argv=None):
    """Command-line entry point: classify one image and print the result.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            means ``sys.argv[1:]``, as with ``argparse``.

    Exits with status 2 (argparse usage error) when ``--image`` is missing or
    does not name an existing file.
    """
    from pathlib import Path

    parser = argparse.ArgumentParser(description="Deepfake image inference")
    parser.add_argument(
        "--image",
        type=str,
        required=True,
        help="Path to input face image",
    )
    args = parser.parse_args(argv)

    # Report a missing input as a clean usage error instead of a traceback.
    # The path is checked up front rather than by catching FileNotFoundError
    # around predict(), because a missing model checkpoint raises that too.
    if not Path(args.image).is_file():
        parser.error(f"image file not found: {args.image}")

    label, confidence = predict(args.image)
    print(f"Prediction : {label}")
    print(f"Confidence : {confidence:.4f}")


if __name__ == "__main__":
    main()