-
Notifications
You must be signed in to change notification settings - Fork 17
Description
I have an original image and a tampered image (linked below). Using ManTraNet, I need to detect whether an incoming image has been tampered with or not.
original image
https://ibb.co/23qttD5k
tampered image
https://ibb.co/9HKYHXnK
The following is the approach we followed:
We trained on 50 original images and 50 tampered images, together with their generated masks:
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import glob, os
import numpy as np
from model.ManTraNet import ManTraNet
# -----------------------------
# 1️⃣ Dataset Class
# -----------------------------
class TamperDataset(Dataset):
    """Paired (image, binary mask) dataset for tamper-localization training.

    Expected layout under ``base_dir``::

        images/forged/*.jpg       masks/forged/*.png
        images/authentic/*.jpg    masks/authentic/*.png

    Images and masks are paired purely by sorted path order inside each class
    folder, so every image folder must hold as many files as its mask folder
    and the filenames must sort into the same order.
    """

    def __init__(self, base_dir, image_size=(512, 512)):
        """Collect image/mask paths under ``base_dir`` and build the transforms.

        Args:
            base_dir: Dataset root laid out as described on the class.
            image_size: Size passed to ``transforms.Resize`` for both images
                and masks. Defaults to ``(512, 512)``.

        Raises:
            ValueError: If no images are found at all, or if an image folder
                and its mask folder contain different numbers of files.
        """
        import warnings

        def _find(pattern):
            # Sorted so that images and masks line up by position.
            return sorted(glob.glob(os.path.join(base_dir, pattern)))

        # The "*" wildcard is required: "images/forged/.jpg" would only match a
        # file literally named ".jpg" and silently yield an empty list.
        forged_imgs = _find("images/forged/*.jpg")
        forged_masks = _find("masks/forged/*.png")
        auth_imgs = _find("images/authentic/*.jpg")
        auth_masks = _find("masks/authentic/*.png")

        for label, imgs, masks in (
            ("forged", forged_imgs, forged_masks),
            ("authentic", auth_imgs, auth_masks),
        ):
            # Pairing is positional, so a count mismatch would silently attach
            # masks to the wrong images.
            if len(imgs) != len(masks):
                raise ValueError(
                    f"{label}: found {len(imgs)} images but {len(masks)} "
                    f"masks under {base_dir!r}"
                )

        self.imgs = forged_imgs + auth_imgs
        self.masks = forged_masks + auth_masks
        if not self.imgs:
            raise ValueError(f"no training images found under {base_dir!r}")

        for label, imgs in (("forged", forged_imgs), ("authentic", auth_imgs)):
            if not imgs:
                # An empty class folder is most likely a path/pattern mistake;
                # surface it instead of training on one class only.
                warnings.warn(
                    f"no {label} images found under {base_dir!r}",
                    stacklevel=2,
                )

        # NOTE(review): images reach the model as ToTensor output in [0, 1];
        # confirm this is the input range the ManTraNet implementation expects.
        self.transform_img = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
        ])
        # Nearest-neighbour resizing keeps mask pixels at their original
        # levels instead of blending them along region boundaries.
        self.transform_mask = transforms.Compose([
            transforms.Resize(
                image_size,
                interpolation=transforms.InterpolationMode.NEAREST,
            ),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of (image, mask) pairs."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Load pair ``idx`` as (image tensor, binary float mask tensor)."""
        # Context managers make sure the underlying file handles are released.
        with Image.open(self.imgs[idx]) as raw_img:
            image = raw_img.convert("RGB")
        with Image.open(self.masks[idx]) as raw_mask:
            mask = raw_mask.convert("L")
        image = self.transform_img(image)
        mask = self.transform_mask(mask)
        # Binarize: pixels brighter than mid-grey count as tampered (1.0).
        mask = (mask > 0.5).float()
        return image, mask
# -----------------------------
# 2️⃣ Custom BCE + Dice Loss
# -----------------------------
import torch.nn as nn
class BCEDiceLoss(nn.Module):
    """Weighted sum of binary cross-entropy and soft Dice loss.

    ``loss = bce_weight * BCE + (1 - bce_weight) * (1 - Dice)``

    The Dice term sums over every element of the batch at once (batch-level
    Dice), not per sample.
    """

    def __init__(self, bce_weight=0.7, from_logits=True):
        """Configure the loss.

        Args:
            bce_weight: Weight of the BCE term, in [0, 1]; the Dice term gets
                ``1 - bce_weight``.
            from_logits: If True (default), ``pred`` is treated as raw logits
                and a sigmoid is applied internally. Set to False when the
                model output already went through a sigmoid, so it is not
                squashed a second time.

        Raises:
            ValueError: If ``bce_weight`` is outside [0, 1].
        """
        super().__init__()
        if not 0.0 <= bce_weight <= 1.0:
            raise ValueError(f"bce_weight must be in [0, 1], got {bce_weight}")
        self.bce_weight = bce_weight
        self.from_logits = from_logits
        # BCEWithLogitsLoss fuses the sigmoid for numerical stability;
        # BCELoss expects probabilities directly.
        self.bce = nn.BCEWithLogitsLoss() if from_logits else nn.BCELoss()

    def dice_loss(self, pred, target):
        """Return ``1 - Dice`` between ``pred`` and ``target`` (smoothed by 1)."""
        smooth = 1.0
        prob = torch.sigmoid(pred) if self.from_logits else pred
        intersection = (prob * target).sum()
        union = prob.sum() + target.sum()
        # The smoothing term keeps the ratio defined when both sums are 0.
        dice = (2.0 * intersection + smooth) / (union + smooth)
        return 1 - dice

    def forward(self, pred, target):
        """Combine the BCE and Dice terms using ``bce_weight``."""
        bce = self.bce(pred, target)
        dice = self.dice_loss(pred, target)
        return self.bce_weight * bce + (1 - self.bce_weight) * dice
# -----------------------------
# 3️⃣ Data Loader
# -----------------------------
# Training split, served in shuffled mini-batches of two, loaded in the main
# process (no worker subprocesses).
train_set = TamperDataset("./dataset/train")
train_loader = DataLoader(
    train_set,
    num_workers=0,
    shuffle=True,
    batch_size=2,
)
# -----------------------------
# 4️⃣ Model Setup
# -----------------------------
# Use CUDA whenever PyTorch reports it as available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE(review): ManTraNet() is constructed here without loading any pretrained
# weights; with only ~100 training images, confirm that training from scratch
# is really intended.
model = ManTraNet()
model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), weight_decay=1e-4, lr=1e-4)

# NOTE(review): BCEDiceLoss applies a sigmoid to the model output (inside
# BCEWithLogitsLoss and again in its Dice term). Confirm that ManTraNet's
# forward returns raw logits. If it already ends in a sigmoid, the outputs are
# squashed twice into roughly [0.5, 0.73] and the loss will plateau.
loss_fn = BCEDiceLoss(bce_weight=0.7)
# -----------------------------
# 5️⃣ Training Loop
# -----------------------------
# Plain supervised loop: one optimizer step per mini-batch, mean batch loss
# reported once per epoch, weights written to disk after the final epoch.
num_epochs = 100
for epoch_idx in range(num_epochs):
    model.train()
    running_loss = 0
    for batch_imgs, batch_masks in train_loader:
        batch_imgs = batch_imgs.to(device)
        batch_masks = batch_masks.to(device)

        batch_loss = loss_fn(model(batch_imgs), batch_masks)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    mean_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch_idx+1}/{num_epochs}] Loss: {mean_loss:.4f}")

torch.save(model.state_dict(), "mantranet_mixed.pth")
print("✅ Training complete. Model saved as mantranet_mixed.pth")
The loss plateaus at about 0.72 and does not decrease further, no matter how many epochs we train for.
Could you please confirm whether this is the right approach?
Thanks
vij