Commit 0925e08
Merge pull request #73 from DtYXs/master: Support gradient accumulation
2 parents: 5bbd36a + b51f9b9

12 files changed (+133 lines, -42 lines)

README.md
Lines changed: 4 additions & 1 deletion

```diff
@@ -16,6 +16,7 @@
 <br><br>

 # News
+* 2023.3.20 Added [gradient accumulation](#gradient_accumulation) support for contrastive learning, to simulate the training effect of a larger batch size.
 * 2023.2.16 Added [FlashAttention](https://github.com/HazyResearch/flash-attention) support to speed up training and reduce memory usage; see [flash_attention.md](flash_attention.md) for details.
 * 2023.1.15 Added support for deploying [ONNX](https://onnx.ai/) and [TensorRT](https://developer.nvidia.com/tensorrt) models (pretrained TensorRT models provided) to speed up feature inference for deployment; see [deployment.md](deployment.md) for details.
 * 2022.12.12 Implemented the [FLIP](https://arxiv.org/abs/2212.00794) training strategy, which can be [activated](#FLIP) during finetuning (thanks to [@zwkkk](https://github.com/zwkkk) for [contributing the code](https://github.com/OFA-Sys/Chinese-CLIP/pull/26) ❤️)
@@ -345,8 +346,10 @@ bash run_scripts/muge_finetune_vit-b-16_rbt-base.sh ${DATAPATH}
 + `valid-batch-size`: per-worker batch size during validation. (Make sure that `total validation samples > batch-size * number of GPUs`, so at least one validation batch is formed.)
 + `valid-step-interval` and `valid-epoch-interval`: validation frequency in steps/epochs; specify -1 to skip validation during training.
 + `grad-checkpointing`: <span id="checkpointing"></span>use the [recomputation strategy](https://pytorch.org/docs/stable/checkpoint.html), which does not keep intermediate activations during the forward pass, trading training time for a smaller memory footprint; useful when GPU memory is insufficient. (`store_true` argument; just add `--grad-checkpointing` in the script. Currently requires Pytorch>1.8.0.)
-+ `mask-ratio`: <span id="FLIP"></span>following the [FLIP](https://arxiv.org/abs/2212.00794) strategy, randomly mask a given ratio of image patches during finetuning to reduce memory usage and speed up training. Defaults to 0.0, which disables the strategy.
++ `mask-ratio`: <span id="FLIP"></span>following the [FLIP](https://arxiv.org/abs/2212.00794) strategy, randomly mask a given ratio of image patches during finetuning to reduce memory usage and speed up training. Defaults to 0.0, which disables the strategy.
 + `use-flash-attention`: use [FlashAttention](https://arxiv.org/abs/2205.14135), which significantly speeds up Chinese-CLIP finetuning and reduces memory usage without affecting results. (`store_true` argument; after setting up the environment, just add `--use-flash-attention` in the script. See [flash_attention.md](flash_attention.md) for details.)
++ `accum-freq`: <span id="gradient_accumulation"></span>gradient accumulation frequency, default 1. Specify an integer greater than 1 to enable contrastive-learning gradient accumulation and simulate a larger batch size. If the per-GPU batch size is `m`, the total batch size is `accum_freq * m * number of GPUs`.
++ `gather-with-grad`: whether to gather features with full gradients in distributed training, off by default.
 + Output options
 + `name`: specifies the output path. Hyperparameter logs, training logs, and produced checkpoints are all saved under `${DATAPATH}/experiments/${name}/`.
 + `save-step-frequency` and `save-epoch-frequency`: the step or epoch interval for saving checkpoints.
```

(In this capture the `mask-ratio` bullet is removed and re-added with no visible difference; the underlying change is likely whitespace-only.)
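The `grad-checkpointing` option above maps onto PyTorch's recomputation API. A minimal, hypothetical sketch of the mechanism (a toy block, independent of Chinese-CLIP; `use_reentrant` assumes a recent PyTorch):

```python
import torch
from torch.utils.checkpoint import checkpoint

# A toy block standing in for a transformer layer. With checkpointing, the
# activations inside `block` are not stored during the forward pass; they
# are recomputed during backward, trading compute for memory.
block = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 16),
)
x = torch.randn(2, 16, requires_grad=True)

y = checkpoint(block, x, use_reentrant=False)
loss = y.sum()
loss.backward()  # gradients are identical to the non-checkpointed path

print(x.grad.shape)  # torch.Size([2, 16])
```

The memory saving grows with the depth of the checkpointed segment, which is why the README recommends it only when GPU memory is insufficient.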

README_En.md
Lines changed: 3 additions & 0 deletions

```diff
@@ -16,6 +16,7 @@ This is the Chinese version of CLIP. We use a large-scale Chinese image-text pai
 <br><br>

 # News
+* 2023.3.20 Support [gradient accumulation](#gradient-accumulation) in contrastive learning to simulate the training effect of a larger batch size.
 * 2023.2.16 Support [FlashAttention](https://github.com/HazyResearch/flash-attention) to improve training speed and reduce memory usage. See [flash_attention_En.md](flash_attention_En.md) for more information.
 * 2023.1.15 Support the conversion of Pytorch models into [ONNX](https://onnx.ai/) or [TensorRT](https://developer.nvidia.com/tensorrt) formats (and provide pretrained TensorRT models) to improve inference speed and meet deployment requirements. See [deployment_En.md](deployment_En.md) for more information.
 * 2022.12.12 Implement [FLIP](https://arxiv.org/abs/2212.00794) strategy, which can be [activated](#FLIP) during finetuning (Thanks [@zwkkk](https://github.com/zwkkk) for [the PR](https://github.com/OFA-Sys/Chinese-CLIP/pull/26) ❤️)
@@ -348,6 +349,8 @@ The configuration for training includes:
 + `grad-checkpointing`: <span id="checkpointing"></span>use [gradient checkpointing](https://pytorch.org/docs/stable/checkpoint.html), which does not keep the activations during forward computation; this strategy trades more computation and iteration time for less GPU memory cost. (`store_true` argument; just add `--grad-checkpointing` in the script to activate it. Requires Pytorch>1.8.0.)
 + `mask-ratio`: <span id="FLIP"></span>use the [FLIP](https://arxiv.org/abs/2212.00794) strategy, which randomly masks a ratio of image patches to save GPU memory and speed up training. Defaults to 0.0, which disables the strategy.
 + `use-flash-attention`: whether to use [FlashAttention](https://arxiv.org/abs/2205.14135), which can significantly speed up the finetune process and reduce memory usage. (`store_true` argument; after configuring the environment, just add `--use-flash-attention` in the script to activate it. See [flash_attention_En.md](flash_attention_En.md) for more information.)
++ `accum-freq`: <span id="gradient-accumulation"></span>gradient accumulation frequency, default 1. Specify an integer greater than 1 to enable gradient accumulation and simulate a larger batch size: if the batch size for a worker is `m`, the total batch size is `accum_freq * m * GPUs`.
++ `gather-with-grad`: whether to enable full distributed gradients for the feature gather, off by default.
 + Outputs
 + `name`: specified output path. Hyperparameter logs, training logs, and checkpoints will be saved at `${DATAPATH}/experiments/${name}/`.
 + `save-step-frequency` and `save-epoch-frequency`: the intervals for saving checkpoints.
```
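The `accum_freq * m * GPUs` arithmetic is easy to check numerically. A sketch with hypothetical numbers (per-worker batch `m = 128`, 8 GPUs, `accum_freq = 4`; none of these values come from the run scripts):

```python
# Hypothetical configuration for illustration only.
m, gpus, accum_freq = 128, 8, 4

# Total simulated batch size per optimizer step:
global_batch = accum_freq * m * gpus

# In the contrastive (InfoNCE) loss, each sample is contrasted against every
# other pair in the step, so the negatives grow with the simulated batch:
negatives_per_sample = global_batch - 1

print(global_batch, negatives_per_sample)  # 4096 4095
```

This is why accumulation helps contrastive learning specifically: more accumulated batches mean more negatives per step, not just a smoother gradient estimate.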

cn_clip/eval/zeroshot_evaluation.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -192,7 +192,7 @@ def run(model, classifier, dataloader, args):
         model_info[k] = v

     model = CLIP(**model_info)
-    convert_weights(model)
+    convert_weights(model)

     # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
     if args.precision == "amp" or args.precision == "fp32":
```

(The `convert_weights(model)` line is removed and re-added; the difference, likely indentation or whitespace, is not visible in this capture.)

cn_clip/training/main.py
Lines changed: 2 additions & 2 deletions

```diff
@@ -163,10 +163,10 @@ def main():
     )
     num_batches = data["train"].dataloader.num_batches
     if args.max_steps is not None:
-        args.max_epochs = ceil(args.max_steps / num_batches)
+        args.max_epochs = ceil(args.max_steps * args.accum_freq / num_batches)
     else:
         assert args.max_epochs is not None and args.max_epochs > 0
-        args.max_steps = num_batches * args.max_epochs
+        args.max_steps = (num_batches // args.accum_freq) * args.max_epochs
     total_steps = args.max_steps
     scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)
```
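The two changed lines keep `max_steps` counted in optimizer steps while each step now consumes `accum_freq` batches from the dataloader. A sketch of the arithmetic with hypothetical numbers (1000 batches per epoch, `accum_freq = 4`):

```python
from math import ceil

num_batches = 1000  # batches per epoch (hypothetical)
accum_freq = 4      # batches consumed per optimizer step

# An epoch now yields fewer optimizer steps:
steps_per_epoch = num_batches // accum_freq  # 250

# Given max_steps, the epochs needed grow by accum_freq, because each
# optimizer step walks through accum_freq batches of data:
max_steps = 600
max_epochs = ceil(max_steps * accum_freq / num_batches)  # ceil(2.4) = 3

# Given max_epochs, total optimizer steps shrink by the same factor:
max_steps_from_epochs = steps_per_epoch * 5  # 5 epochs -> 1250 steps

print(steps_per_epoch, max_epochs, max_steps_from_epochs)  # 250 3 1250
```

Without the `accum_freq` factor, a run configured by `max_steps` would silently train on `accum_freq` times more data than intended.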
172172

cn_clip/training/params.py
Lines changed: 12 additions & 0 deletions

```diff
@@ -174,6 +174,18 @@ def parse_args():
         action="store_true",
         help="Enable flash attention."
     )
+    parser.add_argument(
+        "--accum-freq",
+        type=int,
+        default=1,
+        help="Update the model every --accum-freq steps."
+    )
+    parser.add_argument(
+        "--gather-with-grad",
+        default=False,
+        action="store_true",
+        help="Enable full distributed gradients for the feature gather."
+    )
     # arguments for distributed training
     parser.add_argument(
         "--local_rank",
```
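The two new options behave like standard argparse flags. A minimal stand-alone sketch (mirroring the additions above; the help strings are paraphrased, and the sample command line is hypothetical):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--accum-freq",
    type=int,
    default=1,
    help="Update the model every --accum-freq steps.",
)
parser.add_argument(
    "--gather-with-grad",
    default=False,
    action="store_true",
    help="Enable full distributed gradients for the feature gather.",
)

# argparse converts dashes to underscores in attribute names, which is why
# the training code reads args.accum_freq and args.gather_with_grad.
args = parser.parse_args(["--accum-freq", "4", "--gather-with-grad"])
print(args.accum_freq, args.gather_with_grad)  # 4 True
```

With no flags given, `args.accum_freq` stays 1 and `args.gather_with_grad` stays False, so existing run scripts keep their old behavior.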

cn_clip/training/train.py
Lines changed: 99 additions & 38 deletions

```diff
@@ -8,6 +8,7 @@
 import torch
 import torch.nn as nn
 from torch.cuda.amp import autocast
+import torch.distributed.nn
 import torch.distributed as dist

 from cn_clip.clip.model import convert_state_dict
@@ -16,33 +17,45 @@
 def is_master(args):
     return args.rank == 0

-def get_loss(model, images, texts, loss_img, loss_txt, args):
-    image_features, text_features, logit_scale = model(images, texts, args.mask_ratio)
+def get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_features=None, accum_text_features=None, accum_idx=-1):
+    if args.accum_freq == 1:
+        image_features, text_features, logit_scale = model(images, texts, args.mask_ratio)
+    else:
+        assert accum_image_features and accum_text_features and accum_idx != -1
+        chunk_image_features, chunk_text_features, logit_scale = model(images, texts, args.mask_ratio)
+        image_features = torch.cat(
+            accum_image_features[:accum_idx] + [chunk_image_features] + accum_image_features[accum_idx + 1:])
+        text_features = torch.cat(
+            accum_text_features[:accum_idx] + [chunk_text_features] + accum_text_features[accum_idx + 1:])
     logit_scale = logit_scale.mean()
     if args.aggregate:
         world_size = dist.get_world_size()
         rank = dist.get_rank()

         # We gather tensors from all gpus to get more negatives to contrast with.
-        gathered_image_features = [
-            torch.zeros_like(image_features) for _ in range(world_size)
-        ]
-        gathered_text_features = [
-            torch.zeros_like(text_features) for _ in range(world_size)
-        ]
-        dist.all_gather(gathered_image_features, image_features)
-        dist.all_gather(gathered_text_features, text_features)
-
-        all_image_features = torch.cat(
-            [image_features]
-            + gathered_image_features[:rank]
-            + gathered_image_features[rank + 1 :]
-        )
-        all_text_features = torch.cat(
-            [text_features]
-            + gathered_text_features[:rank]
-            + gathered_text_features[rank + 1 :]
-        )
+        if args.gather_with_grad:
+            all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
+            all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
+        else:
+            gathered_image_features = [
+                torch.zeros_like(image_features) for _ in range(world_size)
+            ]
+            gathered_text_features = [
+                torch.zeros_like(text_features) for _ in range(world_size)
+            ]
+            dist.all_gather(gathered_image_features, image_features)
+            dist.all_gather(gathered_text_features, text_features)
+
+            all_image_features = torch.cat(
+                [image_features]
+                + gathered_image_features[:rank]
+                + gathered_image_features[rank + 1 :]
+            )
+            all_text_features = torch.cat(
+                [text_features]
+                + gathered_text_features[:rank]
+                + gathered_text_features[rank + 1 :]
+            )

         # this is needed to send gradients back everywhere.
         logits_per_image = logit_scale * all_image_features @ all_text_features.t()
@@ -94,17 +107,22 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
     if sampler is not None:
         sampler.set_epoch(epoch)

-    num_batches_per_epoch = dataloader.num_batches
+    num_steps_per_epoch = dataloader.num_batches // args.accum_freq
     data_iter = iter(dataloader)

+    if args.accum_freq > 1:
+        accum_images, accum_texts, accum_image_features, accum_text_features = [], [], [], []
+
     end = time.time()
     epoch_trained_steps = 0
-    for i in range(global_trained_steps - num_batches_per_epoch * epoch, num_batches_per_epoch):
+    for i in range(0, dataloader.num_batches):
         batch = next(data_iter)
-        step = num_batches_per_epoch * epoch + i
+
+        i_accum = i // args.accum_freq
+        step = num_steps_per_epoch * epoch + i_accum
         # reach the args.max_steps, exit training:
         if step >= args.max_steps:
-            logging.info("Stopping training due to step {} has reached max_steps {}".format(step, args.max_steps))
+            logging.info("Stopping training due to step {} has reached max_steps {}".format(step, args.max_steps // args.accum_freq))
             return epoch_trained_steps
         scheduler(step)
@@ -120,18 +138,60 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained

         m = model.module

-        # with automatic mixed precision.
-        if args.precision == "amp":
-            with autocast():
-                total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args)
-                scaler.scale(total_loss).backward()
-                scaler.step(optimizer)
-            scaler.update()
-
-        else:
-            total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args)
-            total_loss.backward()
-            optimizer.step()
+        if args.accum_freq == 1:
+            # with automatic mixed precision.
+            if args.precision == "amp":
+                with autocast():
+                    total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args)
+                    scaler.scale(total_loss).backward()
+                    scaler.step(optimizer)
+                scaler.update()
+
+            else:
+                total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args)
+                total_loss.backward()
+                optimizer.step()
+        else:
+            # First, cache the features without any gradient tracking.
+            with torch.no_grad():
+                with autocast(enabled=(args.precision == "amp")):
+                    chunk_image_features, chunk_text_features, _ = model(images, texts)
+                accum_image_features.append(chunk_image_features)
+                accum_text_features.append(chunk_text_features)
+
+                accum_images.append(images)
+                accum_texts.append(texts)
+
+            # If (i + 1) % accum_freq is not zero, move on to the next batch.
+            if ((i + 1) % args.accum_freq) > 0:
+                # FIXME this makes data time logging unreliable when accumulating
+                continue
+
+            # Now, ready to take gradients for the last accum_freq batches.
+            # Re-do the forward pass for those batches, and use the cached features from the other batches as negatives.
+            # Call backwards each time, but only step optimizer at the end.
+            optimizer.zero_grad()
+            for j in range(args.accum_freq):
+                images = accum_images[j]
+                texts = accum_texts[j]
+                with autocast(enabled=(args.precision == "amp")):
+                    # `total_loss` and `acc` are coarsely sampled, taking only the last result in the loop.
+                    # Although each result should be the same in theory, it will be slightly different in practice.
+                    total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_features, accum_text_features, j)
+                if args.precision == "amp":
+                    scaler.scale(total_loss).backward()
+                else:
+                    total_loss.backward()
+
+            if args.precision == "amp":
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                optimizer.step()
+
+        # reset gradient accum, if enabled
+        if args.accum_freq > 1:
+            accum_images, accum_texts, accum_image_features, accum_text_features = [], [], [], []

         # Note: we clamp to 4.6052 = ln(100), as in the original paper.
         m.logit_scale.data = torch.clamp(m.logit_scale.data, 0, 4.6052)
@@ -142,10 +202,11 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
         epoch_trained_steps += 1

         if is_master(args) and ((step + 1) % args.log_interval) == 0:
-            num_samples = (i + 1) * len(images) * args.world_size
+            batch_size = len(images) * args.accum_freq
+            num_samples = (i_accum + 1) * batch_size * args.world_size
             samples_per_epoch = dataloader.num_samples
-            percent_complete = 100.0 * (i + 1) / num_batches_per_epoch
+            percent_complete = 100.0 * (i_accum + 1) / num_steps_per_epoch

             logging.info(
                 f"Global Steps: {step + 1}/{args.max_steps} | " +
                 f"Train Epoch: {epoch + 1} [{num_samples}/{samples_per_epoch} ({percent_complete:.0f}%)] | " +
@@ -156,7 +217,7 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
                 f"Batch Time: {batch_time:.3f}s | " +
                 f"LR: {optimizer.param_groups[0]['lr']:5f} | " +
                 f"logit_scale: {m.logit_scale.data:.3f} | " +
-                f"Global Batch Size: {len(images) * args.world_size}"
+                f"Global Batch Size: {batch_size * args.world_size}"
             )

         if args.val_data is not None and args.valid_step_interval is not None and ((step + 1) % args.valid_step_interval) == 0:
```
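The cache-then-re-forward scheme in `train.py` can be verified on toy tensors. The sketch below is a simplified, hypothetical setup (linear maps standing in for the two towers, a fixed logit scale, no distributed gather): it caches chunk features without gradients, re-forwards each chunk while splicing it into the cached features as negatives, and accumulates backward calls, reproducing the single large-batch gradient.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy "towers" standing in for the image and text encoders.
img_enc = torch.nn.Linear(8, 4, bias=False)
txt_enc = torch.nn.Linear(8, 4, bias=False)

def encode(images, texts):
    return (F.normalize(img_enc(images), dim=-1),
            F.normalize(txt_enc(texts), dim=-1))

def clip_loss(img_f, txt_f, scale=10.0):  # fixed scale sidesteps the logit_scale subtlety
    logits = scale * img_f @ txt_f.t()
    labels = torch.arange(logits.size(0))
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2

images, texts = torch.randn(4, 8), torch.randn(4, 8)

# Reference: one large-batch forward/backward.
i_f, t_f = encode(images, texts)
clip_loss(i_f, t_f).backward()
ref_grad = img_enc.weight.grad.clone()

# Accumulation with accum_freq = 2: cache features without gradients, then
# re-forward each chunk and splice it into the cache, as get_loss() does.
img_enc.zero_grad(); txt_enc.zero_grad()
accum_freq, chunk = 2, 2
with torch.no_grad():
    cached = [encode(images[k*chunk:(k+1)*chunk], texts[k*chunk:(k+1)*chunk])
              for k in range(accum_freq)]
for j in range(accum_freq):
    i_j, t_j = encode(images[j*chunk:(j+1)*chunk], texts[j*chunk:(j+1)*chunk])
    all_i = torch.cat([c[0] for c in cached[:j]] + [i_j] + [c[0] for c in cached[j+1:]])
    all_t = torch.cat([c[1] for c in cached[:j]] + [t_j] + [c[1] for c in cached[j+1:]])
    clip_loss(all_i, all_t).backward()  # gradients flow only through chunk j

# The accumulated gradient matches the single large-batch gradient.
print(torch.allclose(ref_grad, img_enc.weight.grad, atol=1e-5))  # True
```

This is why plain loss-averaging accumulation is not enough for contrastive learning: each chunk's loss must still see the full batch of negatives, which is exactly what the cached features provide.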

run_scripts/coco-cn_finetune_vit-b-16_rbt-base.sh
Lines changed: 2 additions & 0 deletions

```diff
@@ -46,6 +46,7 @@ context_length=52
 warmup=6
 batch_size=1024
 valid_batch_size=128
+accum_freq=1
 lr=3e-5
 wd=0.001
 max_epochs=20
@@ -75,6 +76,7 @@ python3 -m torch.distributed.launch --nproc_per_node=${GPUS_PER_NODE} --nnodes=$
           --valid-batch-size=${valid_batch_size} \
           --valid-step-interval=${valid_step_interval} \
           --valid-epoch-interval=${valid_epoch_interval} \
+          --accum-freq=${accum_freq} \
           --lr=${lr} \
           --wd=${wd} \
           --max-epochs=${max_epochs} \
```

run_scripts/flickr30k_finetune_vit-b-16_rbt-base.sh
Lines changed: 2 additions & 0 deletions

```diff
@@ -46,6 +46,7 @@ context_length=52
 warmup=100
 batch_size=128
 valid_batch_size=128
+accum_freq=1
 lr=5e-5
 wd=0.001
 max_epochs=3 # or you can alternatively specify --max-steps
@@ -75,6 +76,7 @@ python3 -m torch.distributed.launch --nproc_per_node=${GPUS_PER_NODE} --nnodes=$
           --valid-batch-size=${valid_batch_size} \
           --valid-step-interval=${valid_step_interval} \
           --valid-epoch-interval=${valid_epoch_interval} \
+          --accum-freq=${accum_freq} \
           --lr=${lr} \
           --wd=${wd} \
           --max-epochs=${max_epochs} \
```

run_scripts/flickr30k_finetune_vit-b-16_rbt-base_flip.sh
Lines changed: 2 additions & 0 deletions

```diff
@@ -46,6 +46,7 @@ context_length=52
 warmup=100
 batch_size=128
 valid_batch_size=128
+accum_freq=1
 lr=5e-5
 wd=0.001
 max_epochs=3 # or you can alternatively specify --max-steps
@@ -76,6 +77,7 @@ python3 -m torch.distributed.launch --nproc_per_node=${GPUS_PER_NODE} --nnodes=$
           --valid-batch-size=${valid_batch_size} \
           --valid-step-interval=${valid_step_interval} \
           --valid-epoch-interval=${valid_epoch_interval} \
+          --accum-freq=${accum_freq} \
           --lr=${lr} \
           --wd=${wd} \
           --max-epochs=${max_epochs} \
```

run_scripts/muge_finetune_vit-b-16_rbt-base.sh
Lines changed: 2 additions & 0 deletions

```diff
@@ -46,6 +46,7 @@ context_length=52
 warmup=100
 batch_size=128
 valid_batch_size=128
+accum_freq=1
 lr=5e-5
 wd=0.001
 max_epochs=3 # or you can alternatively specify --max-steps
@@ -75,6 +76,7 @@ python3 -m torch.distributed.launch --nproc_per_node=${GPUS_PER_NODE} --nnodes=$
           --valid-batch-size=${valid_batch_size} \
           --valid-step-interval=${valid_step_interval} \
           --valid-epoch-interval=${valid_epoch_interval} \
+          --accum-freq=${accum_freq} \
           --lr=${lr} \
           --wd=${wd} \
           --max-epochs=${max_epochs} \
```
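All four run scripts wire in `accum_freq=1`, which leaves the feature off; raising the value is how accumulation is activated. A hypothetical edit (variable names as in the scripts, values invented for illustration):

```shell
# Simulate a 4x larger batch without growing per-GPU memory for one forward pass.
batch_size=128
accum_freq=4

# The flag is forwarded to cn_clip.training.main like the other options:
extra_args="--batch-size=${batch_size} --accum-freq=${accum_freq}"
echo "${extra_args}"  # --batch-size=128 --accum-freq=4
```

Note the time cost: each optimizer step re-runs the forward pass for the accumulated batches, so wall-clock time per step grows roughly with `accum_freq`.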
