-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrain.py
More file actions
105 lines (84 loc) · 5.01 KB
/
train.py
File metadata and controls
105 lines (84 loc) · 5.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import math
import torch
from torch.cuda.amp import autocast
import utils
def mlm(text, text_input, tokenizer, device, mask_generator, config):
text_masked = tokenizer(text, padding='max_length', truncation=True, max_length=config['max_tokens'],
return_tensors="pt").to(device)
text_ids_masked = text_masked.input_ids
masked_pos = torch.empty((text_ids_masked.shape[0], config['max_masks']), dtype=torch.int64, device=device)
masked_ids = torch.empty((text_ids_masked.shape[0], config['max_masks']), dtype=torch.long, device=device)
for index, text_id in enumerate(text_ids_masked):
text_ids_masked_, masked_pos_ = mask_generator(text_id)
masked_ids_ = [text_input.input_ids[index][p].item() for p in masked_pos_]
n_pad = config['max_masks'] - len(masked_ids_)
masked_pos_ = masked_pos_ + [0] * n_pad
masked_pos_ = torch.tensor(masked_pos_, dtype=torch.int64).to(device)
masked_ids_ = masked_ids_ + [-100] * n_pad
masked_ids_ = torch.tensor(masked_ids_, dtype=torch.long).to(device)
masked_pos[index] = masked_pos_
masked_ids[index] = masked_ids_
return text_ids_masked, masked_pos, masked_ids
def train_model(model, data_loader, optimizer, scaler, tokenizer, epoch, device, scheduler, config, mask_generator):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('loss_itc', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
metric_logger.add_meter('loss_mlm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
header = 'Train Epoch: [{}]'.format(epoch)
print_freq = 50
for i, (image, text, text_eda, idx, pose, hard_i, hard_i_pose, hard_caption) in enumerate(
metric_logger.log_every(data_loader, print_freq, header)):
image = image.to(device, non_blocking=True)
text_input = tokenizer(text, padding='max_length', truncation=True, max_length=config['max_tokens'],
return_tensors="pt").to(device)
text_ids_masked, masked_pos, masked_ids = mlm(text, text_input, tokenizer, device, mask_generator, config)
text_input_eda = tokenizer(text_eda, padding='max_length', truncation=True, max_length=config['max_tokens'],
return_tensors="pt").to(device)
text_ids_eda = text_input_eda.input_ids
text_atts_eda = text_input_eda.attention_mask
idx = idx.to(device, non_blocking=True)
if config.get('be_hard', False):
hard_i = hard_i.to(device, non_blocking=True)
if config.get('be_pose_img', False):
hard_i_pose = hard_i_pose.to(device, non_blocking=True)
else:
hard_i_pose = None
hard_text_input = tokenizer(hard_caption, padding='max_length', truncation=True,
max_length=config['max_tokens'], return_tensors="pt").to(device)
hard_text_ids = hard_text_input.input_ids
hard_text_atts = hard_text_input.attention_mask
else:
hard_i, hard_i_pose, hard_text_ids, hard_text_atts = None, None, None, None
if config.get('be_pose_img', False):
pose = pose.to(device, non_blocking=True)
else:
pose = None
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
loss_itc, loss_itm, loss_mlm = \
model(image, text_input.input_ids, text_input.attention_mask,
text_ids_masked=text_ids_masked, masked_pos=masked_pos, masked_ids=masked_ids,
idx=idx, text_ids_eda=text_ids_eda, text_atts_eda=text_atts_eda,
pose=pose, hard_i=hard_i, hard_i_pose=hard_i_pose,
hard_text_ids=hard_text_ids, hard_text_atts=hard_text_atts,
)
loss = loss_itc + loss_itm + loss_mlm
scaler.scale(loss).backward()
scaler.step(optimizer)
scale = scaler.get_scale()
scaler.update()
skip_lr_sched = (scale > scaler.get_scale())
if not skip_lr_sched:
scheduler.step()
optimizer.zero_grad()
metric_logger.update(loss_itc=loss_itc.item())
metric_logger.update(loss_itm=loss_itm.item())
metric_logger.update(loss_mlm=loss_mlm.item())
metric_logger.update(loss=loss.item())
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger.global_avg())
return {k: "{:.5f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}