forked from IQTLabs/DeepFakeDetection
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_model.py
More file actions
98 lines (82 loc) · 3.87 KB
/
train_model.py
File metadata and controls
98 lines (82 loc) · 3.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import yaml
import pandas as pd
import argparse
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from dfdet import *
# Command-line interface for the training script.
# (Original had a duplicated `parser = parser =` assignment and described
# itself as a "Batch inference script" — both copy-paste errors.)
parser = argparse.ArgumentParser(description='Training script')
parser.add_argument('--gpu', dest='gpu', default=0,
                    type=int, help='Target gpu')
parser.add_argument('--config', dest='config',
                    default='./config_files/training.yaml', type=str,
                    help='Config file with paths and MTCNN set-up')
# NOTE: type=bool treats any non-empty string (including 'False') as True;
# parse common truthy spellings explicitly so `--verbose False` works.
parser.add_argument('--verbose', dest='verbose', default=False,
                    type=lambda s: str(s).lower() in ('true', '1', 'yes'),
                    help='Verbose switch')
parser.add_argument('--load', dest='chpt', default=None, type=str,
                    help='Checkpoint to resume training')
def train_test_split(df, fraction=0.8, random_state=200):
    """ Stratified train/test split of the faces metadata.

    Keeps only clips with at least 30 extracted frames, then samples
    `fraction` of each label class (0 and 1) for training; the remaining
    rows form the test set.

    Parameters
    ----------
    df : pd.DataFrame
        Metadata with at least 'frames' and 'label' columns.
    fraction : float
        Fraction of each class assigned to the training split.
    random_state : int
        Seed for reproducible sampling.

    Returns
    -------
    (pd.DataFrame, pd.DataFrame)
        Train and test frames, each renumbered 0..n-1.
    """
    df = df[df['frames'] >= 30]
    # Sample each class separately so the split is stratified by label.
    train = pd.concat([df[df['label'] == 1].sample(frac=fraction,
                                                   random_state=random_state),
                       df[df['label'] == 0].sample(frac=fraction,
                                                   random_state=random_state)])
    test = df.drop(train.index)
    # Fix: reindex() with no arguments is a no-op copy; reset_index(drop=True)
    # actually renumbers the rows as intended.
    return train.reset_index(drop=True), test.reset_index(drop=True)
def paired_split(df, fraction=0.8, random_state=200, nframes=30):
    """ Train/test split that keeps fake clips paired with their originals.

    Samples (1 - fraction) of the fake clips for the test set and pulls
    their source (real) clips in with them, so no real video that produced
    a test fake leaks into the training split.

    Parameters
    ----------
    df : pd.DataFrame
        Metadata with 'nframes', 'label' ('REAL'/'FAKE') and 'File' columns.
    fraction : float
        Fraction of fake clips kept for training.
    random_state : int
        Seed for reproducible sampling.
    nframes : int
        Minimum number of frames a clip must have to be used.

    Returns
    -------
    (pd.DataFrame, pd.DataFrame)
        Train and test frames, each renumbered 0..n-1.
    """
    # Copy the slice so the label rewrites below don't hit SettingWithCopy.
    df = df[df['nframes'] >= nframes].copy()
    # NOTE(review): a .json path read via read_csv — works only if the file
    # is actually CSV-formatted; confirm against the data-prep pipeline.
    all_df = pd.read_csv('./training_metadata.json')
    # Map string labels to the 0/1 encoding the datasets expect.
    df.loc[df['label'] == 'REAL', 'label'] = 0
    df.loc[df['label'] == 'FAKE', 'label'] = 1
    test_fake = df[df['label'] == 1].sample(frac=1.0 - fraction,
                                            random_state=random_state)
    # Originals that produced the sampled fakes must also be held out.
    originals = all_df.loc[all_df.File.isin(list(test_fake['File'])),
                           'original']
    test_real = df[df.File.isin(list(originals))]
    test = pd.concat([test_real, test_fake])
    train = df.drop(test.index)
    # Fix: reindex() with no arguments is a no-op copy; reset_index(drop=True)
    # actually renumbers the rows as intended.
    return train.reset_index(drop=True), test.reset_index(drop=True)
if __name__ == '__main__':
    args = parser.parse_args()
    with open(args.config) as f:
        # safe_load: the config is plain data; yaml.load without a Loader is
        # deprecated and can instantiate arbitrary objects from the file.
        config = yaml.safe_load(f)
    df = pd.read_csv('{}/faces_metadata.csv'.format(config['data_path']))
    if bool(config['paired_split']):
        # Fix: pass the frame threshold by keyword — the original positional
        # call fed config['frames'] into random_state instead of nframes.
        train, test = paired_split(df, config['training_fraction'],
                                   nframes=config['frames'])
    else:
        train, test = train_test_split(df, config['training_fraction'])
    # Training set: augmentation on, optionally stochastic frame sampling.
    trainset = DFDC_Dataset(
        df=train, size=config['size'], mean=config['mean'], std=config['std'],
        augment=config['augment'], frames=config['frames'],
        stochastic=config['stochastic'])
    trainloader = DataLoader(
        trainset, batch_size=config['batch_size'], shuffle=True,
        num_workers=16)
    # Test set: no augmentation, no shuffling, for stable evaluation.
    testset = DFDC_Dataset(
        df=test, size=config['size'], mean=config['mean'], std=config['std'],
        augment=False, frames=config['frames'],
        stochastic=config['stochastic'])
    testloader = DataLoader(
        testset, batch_size=config['batch_size'], shuffle=False,
        num_workers=16)
    # Single-logit ConvLSTM classifier (BCE below implies sigmoid output).
    model = ConvLSTM(
        num_classes=1, lstm_layers=config['lstm_layers'],
        attention=config['attention'], encoder=config['encoder'],
        calibrating=False, fine_tune=config['fine_tune'])
    if args.chpt is not None:
        # Resume from a checkpoint produced by an earlier run.
        print('Loading file: {}'.format(args.chpt))
        chpt_file = torch.load(args.chpt)
        model.load_state_dict(chpt_file['model'])
    optim_, sched_ = CreateOptim(model.parameters(), lr=float(config['lr']),
                                 weight_decay=float(config['weight_decay']),
                                 threshold=0.0001, factor=0.7)
    losses = []
    averages = []
    train_dfd(model=model, dataloader=trainloader, testloader=testloader,
              optim=optim_, scheduler=sched_, criterion=nn.BCELoss(),
              losses=losses, averages=averages, n_epochs=config['n_epochs'],
              e_saves=config['e_saves'], device='cuda:{}'.format(args.gpu),
              verbose=args.verbose)