
Commit 2c38d03 (merge, 2 parents: 0925e08 + 0d4013d)

Merge pull request #101 from DtYXs/pytorch2.0_adaption

Adaptation to PyTorch 2.0

12 files changed: +32 −27 lines

README.md

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@
 <br><br>

 # News
+* 2023.5.9 Chinese-CLIP now supports PyTorch 2.0.
 * 2023.3.20 Added [gradient accumulation](#gradient_accumulation) support for contrastive learning, simulating the training effect of a larger batch size.
 * 2023.2.16 Added [FlashAttention](https://github.com/HazyResearch/flash-attention) support, improving training speed and reducing memory usage; see [flash_attention.md](flash_attention.md) for details.
 * 2023.1.15 Added support for deploying [ONNX](https://onnx.ai/) and [TensorRT](https://developer.nvidia.com/tensorrt) models (pretrained TensorRT models provided), improving feature-inference speed for deployment; see [deployment.md](deployment.md) for details.

README_En.md

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@ This is the Chinese version of CLIP. We use a large-scale Chinese image-text pai
 <br><br>

 # News
+* 2023.5.9 Chinese-CLIP has been adapted to PyTorch 2.0.
 * 2023.3.20 Support [gradient accumulation](#gradient-accumulation) in contrastive learning to simulate the training effect of a larger batch size.
 * 2023.2.16 Support [FlashAttention](https://github.com/HazyResearch/flash-attention) to improve training speed and reduce memory usage. See [flash_attention_En.md](flash_attention_En.md) for more information.
 * 2023.1.15 Support the conversion of PyTorch models into [ONNX](https://onnx.ai/) or [TensorRT](https://developer.nvidia.com/tensorrt) formats (and provide pretrained TensorRT models) to improve inference speed and meet deployment requirements. See [deployment_En.md](deployment_En.md) for more information.

cn_clip/training/main.py

Lines changed: 6 additions & 3 deletions

@@ -48,7 +48,7 @@ def main():
     args = parse_args()

     # Set distributed group
-    args.local_device_rank = max(args.local_rank, 0)
+    args.local_device_rank = int(os.environ["LOCAL_RANK"])
     torch.cuda.set_device(args.local_device_rank)
     args.device = torch.device("cuda", args.local_device_rank)

@@ -108,7 +108,7 @@ def main():

     if args.grad_checkpointing:
         assert not torch_version_str_compare_lessequal(torch.__version__, "1.8.0"), \
-                "Currently our grad_checkpointing is not compatible with torch version <= 1.8.0."
+            "Currently our grad_checkpointing is not compatible with torch version <= 1.8.0."
         model.set_grad_checkpointing()
         logging.info("Grad-checkpointing activated.")

@@ -133,6 +133,9 @@ def main():
     # In other cases, set find_unused_parameters to False
     find_unused_parameters = torch_version_str_compare_lessequal(torch.__version__, "1.8.0")
     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_device_rank], find_unused_parameters=find_unused_parameters)
+    # Have to set this when activating grad checkpointing in Pytorch >= 2.0.0
+    if args.grad_checkpointing and not torch_version_str_compare_lessequal(torch.__version__, "1.14.0"):
+        model._set_static_graph()

     if args.precision == "fp16":
         convert_weights(model)

@@ -218,7 +221,7 @@ def main():
         model.load_state_dict(sd)
         # Restore the epoch and steps info, reload the dataset and dataloader for the resume epoch
         if not args.reset_data_offset:
-            start_epoch = checkpoint["epoch"] - 1
+            start_epoch = checkpoint["epoch"]
             steps = checkpoint["step"]
             data = get_data(args,
                             epoch_id=start_epoch,
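The hunks above gate behavior on `torch_version_str_compare_lessequal`, whose implementation is not part of this diff. A minimal sketch under the assumption that it compares numeric release tuples and ignores local build suffixes such as `+cu117` (the real helper in `cn_clip` may differ); the `"1.14.0"` threshold matches the fact that PyTorch 2.0 pre-release nightlies were versioned 1.14 before the rename:

```python
import re

def torch_version_str_compare_lessequal(version_a, version_b):
    # Hypothetical sketch: compare two version strings by their numeric
    # release components, ignoring suffixes such as "+cu117" or "a0".
    def release_tuple(version):
        # Cut at the first letter or "+" sign, then parse the dotted digits.
        core = re.split(r"[+a-zA-Z]", version, maxsplit=1)[0].rstrip(".")
        return tuple(int(part) for part in core.split(".") if part)

    a, b = release_tuple(version_a), release_tuple(version_b)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))  # pad so "1.8" compares like "1.8.0"
    b += (0,) * (width - len(b))
    return a <= b

print(torch_version_str_compare_lessequal("2.0.1+cu117", "1.14.0"))  # → False
```

With this semantics, `not torch_version_str_compare_lessequal(torch.__version__, "1.14.0")` is true exactly for PyTorch 2.0 and later, which is when the commit calls `model._set_static_graph()` for grad checkpointing under DDP.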

cn_clip/training/params.py

Lines changed: 0 additions & 6 deletions

@@ -187,12 +187,6 @@ def parse_args():
         help="enable full distributed gradient for feature gather"
     )
     # arguments for distributed training
-    parser.add_argument(
-        "--local_rank",
-        type=int,
-        default=-1,
-        help="For distributed training: local_rank."
-    )
     parser.add_argument(
         "--skip-aggregate",
         default=False,
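With `--local_rank` removed from the CLI, the launcher is expected to hand the per-node rank to each worker through the environment instead, which is the convention `torch.distributed.launch --use_env` and `torchrun` follow. A minimal sketch of that handoff (`get_local_device_rank` is an illustrative helper, not part of the repo):

```python
import os

def get_local_device_rank(default=0):
    # torchrun (and torch.distributed.launch --use_env) export LOCAL_RANK,
    # RANK, and WORLD_SIZE into each worker's environment; read the
    # per-node device rank from there instead of a --local_rank argument.
    return int(os.environ.get("LOCAL_RANK", default))

os.environ["LOCAL_RANK"] = "3"   # simulate a launcher-spawned worker
print(get_local_device_rank())   # → 3
```

Note that the `main.py` change in this commit uses `os.environ["LOCAL_RANK"]` directly, so it fails fast with a `KeyError` when the script is run outside a launcher rather than silently falling back to rank 0.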

flash_attention.md

Lines changed: 9 additions & 6 deletions

@@ -6,9 +6,12 @@ Chinese-CLIP training now supports acceleration through [FlashAttention]

 ## Environment preparation

-+ Nvidia GPUs with the **Volta** or **Ampere** architecture (e.g. A100, RTX 3090, T4, RTX 2080); see [this table](https://en.wikipedia.org/wiki/CUDA#GPUs_supported) for the GPU models of each Nvidia architecture.
-+ CUDA 11 with NVCC.
-+ **FlashAttention**: install FlashAttention by running `pip install flash-attn`; see the [FlashAttention project repository](https://github.com/HazyResearch/flash-attention).
++ Nvidia GPUs with the **Turing**, **Ampere**, **Ada**, or **Hopper** architecture (e.g. H100, A100, RTX 3090, T4, RTX 2080); see [this table](https://en.wikipedia.org/wiki/CUDA#GPUs_supported) for the GPU models of each Nvidia architecture.
++ CUDA 11.4 or above.
++ PyTorch 1.12 or above.
++ **FlashAttention**: install FlashAttention by running `pip install flash-attn`.
+
+See the [FlashAttention project repository](https://github.com/HazyResearch/flash-attention) for more information.

 ## Use it in Chinese-CLIP!

@@ -17,7 +20,7 @@

 ## Training speed and memory usage comparison

-Enabling FlashAttention significantly speeds up Chinese-CLIP finetuning and reduces its memory usage without affecting accuracy. Our experiments were run on a machine with 8 A100 GPUs (80GB memory).
+Enabling FlashAttention significantly speeds up Chinese-CLIP finetuning and reduces its memory usage without affecting accuracy. Our experiments were run on a machine with 8 A100 GPUs (80GB memory), using FlashAttention 0.2.8 and PyTorch 1.10.1.

 For finetuning at the same batch size, we list the FP16 batch time and memory usage of each model scale before and after enabling FlashAttention: training becomes faster and more memory-efficient, with larger gains for bigger models.

@@ -31,7 +34,7 @@
 <td width="120%">CN-CLIP<sub>RN50</sub></td><td>1200*8</td><td>1.710</td><td>1.680</td><td>1.02×</td>
 </tr>
 <tr align="center">
-<td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>400*8</td><td>1.477</td><td>0.960</td><td>1.54×</td>
+<td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>450*8</td><td>1.477</td><td>0.960</td><td>1.54×</td>
 </tr>
 <tr align="center">
 <td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>128*8</td><td>1.293</td><td>0.785</td><td>1.65×</td>

@@ -55,7 +58,7 @@
 <td width="120%">CN-CLIP<sub>RN50</sub></td><td>1200*8</td><td>79</td><td>75</td>
 </tr>
 <tr align="center">
-<td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>400*8</td><td>80</td><td>56</td>
+<td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>450*8</td><td>80</td><td>56</td>
 </tr>
 <tr align="center">
 <td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>128*8</td><td>77</td><td>50</td>

flash_attention_En.md

Lines changed: 9 additions & 6 deletions

@@ -6,9 +6,12 @@ Chinese-CLIP now supports the acceleration of training process through [FlashAtt

 ## Environmental Preparation

-+ Nvidia GPUs **with Volta or Ampere architecture** (such as A100, RTX 3090, T4, and RTX 2080). Please refer to [this document](https://en.wikipedia.org/wiki/CUDA#GPUs_supported) for the corresponding GPUs of each Nvidia architecture.
-+ CUDA 11, NVCC
-+ **FlashAttention**: Install FlashAttention by executing `pip install flash-attn`. Please refer to the [FlashAttention project repository](https://github.com/HazyResearch/flash-attention).
++ Nvidia GPUs **with Turing, Ampere, Ada, or Hopper architecture** (such as H100, A100, RTX 3090, T4, and RTX 2080). Please refer to [this document](https://en.wikipedia.org/wiki/CUDA#GPUs_supported) for the corresponding GPUs of each Nvidia architecture.
++ CUDA 11.4 and above.
++ PyTorch 1.12 and above.
++ **FlashAttention**: Install FlashAttention by executing `pip install flash-attn`.
+
+Please refer to the [FlashAttention project repository](https://github.com/HazyResearch/flash-attention) for more information.

 ## Use it in Chinese-CLIP!

@@ -17,7 +20,7 @@ Applying FlashAttention to the finetune process of Chinese-CLIP is very simple,

 ## Training Speed and Memory Usage Comparison

-Enabling FlashAttention can significantly speed up the finetune process and reduce the memory usage of Chinese-CLIP without affecting the precision. Our experiments are conducted on an 8-card A100 GPU (80GB memory) machine.
+Enabling FlashAttention can significantly speed up the finetune process and reduce the memory usage of Chinese-CLIP without affecting the precision. Our experiments are conducted on an 8-card A100 GPU (80GB memory) machine, with FlashAttention 0.2.8 and PyTorch 1.10.1.

 We present the comparison of the batch time and memory usage of FP16 precision finetune for each scale model. The improvement in training speed and reduction in memory usage are more significant for larger models.

@@ -31,7 +34,7 @@ We present the comparison of the batch time and memory usage of FP16 precision f
 <td width="120%">CN-CLIP<sub>RN50</sub></td><td>1200*8</td><td>1.710</td><td>1.680</td><td>1.02×</td>
 </tr>
 <tr align="center">
-<td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>400*8</td><td>1.477</td><td>0.960</td><td>1.54×</td>
+<td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>450*8</td><td>1.477</td><td>0.960</td><td>1.54×</td>
 </tr>
 <tr align="center">
 <td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>128*8</td><td>1.293</td><td>0.785</td><td>1.65×</td>

@@ -55,7 +58,7 @@ We present the comparison of the batch time and memory usage of FP16 precision f
 <td width="120%">CN-CLIP<sub>RN50</sub></td><td>1200*8</td><td>79</td><td>75</td>
 </tr>
 <tr align="center">
-<td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>400*8</td><td>80</td><td>56</td>
+<td width="120%">CN-CLIP<sub>ViT-B/16</sub></td><td>450*8</td><td>80</td><td>56</td>
 </tr>
 <tr align="center">
 <td width="120%">CN-CLIP<sub>ViT-L/14</sub></td><td>128*8</td><td>77</td><td>50</td>
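The architecture requirement above can be checked at runtime: Turing corresponds to compute capability 7.5, Ampere to 8.0/8.6, Ada to 8.9, and Hopper to 9.0. A small sketch, where the 7.5 threshold is our reading of the FlashAttention requirements rather than an official API:

```python
def supports_flash_attention(compute_capability):
    # Turing (sm_75), Ampere (sm_80/86), Ada (sm_89), and Hopper (sm_90)
    # all have compute capability >= 7.5; Volta (sm_70) and older do not.
    major, minor = compute_capability
    return (major, minor) >= (7, 5)

# On a CUDA machine one would pass the tuple returned by
# torch.cuda.get_device_capability(), e.g.:
#   supports_flash_attention(torch.cuda.get_device_capability(0))
print(supports_flash_attention((8, 0)))  # A100 (Ampere) → True
```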

run_scripts/coco-cn_finetune_vit-b-16_rbt-base.sh

Lines changed: 1 addition & 1 deletion

@@ -57,7 +57,7 @@ text_model=RoBERTa-wwm-ext-base-chinese
 use_augment="--use-augment"
 # use_augment=""

-python3 -m torch.distributed.launch --nproc_per_node=${GPUS_PER_NODE} --nnodes=${WORKER_CNT} --node_rank=${RANK} \
+python3 -m torch.distributed.launch --use_env --nproc_per_node=${GPUS_PER_NODE} --nnodes=${WORKER_CNT} --node_rank=${RANK} \
           --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} cn_clip/training/main.py \
           --train-data=${train_data} \
           --val-data=${val_data} \

run_scripts/flickr30k_finetune_vit-b-16_rbt-base.sh

Lines changed: 1 addition & 1 deletion

@@ -57,7 +57,7 @@ text_model=RoBERTa-wwm-ext-base-chinese
 use_augment="--use-augment"
 # use_augment=""

-python3 -m torch.distributed.launch --nproc_per_node=${GPUS_PER_NODE} --nnodes=${WORKER_CNT} --node_rank=${RANK} \
+python3 -m torch.distributed.launch --use_env --nproc_per_node=${GPUS_PER_NODE} --nnodes=${WORKER_CNT} --node_rank=${RANK} \
           --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} cn_clip/training/main.py \
           --train-data=${train_data} \
           --val-data=${val_data} \

run_scripts/flickr30k_finetune_vit-b-16_rbt-base_flip.sh

Lines changed: 1 addition & 1 deletion

@@ -58,7 +58,7 @@ mask_ratio=0.5 # use flip: set mask ratio
 use_augment="--use-augment"
 # use_augment=""

-python3 -m torch.distributed.launch --nproc_per_node=${GPUS_PER_NODE} --nnodes=${WORKER_CNT} --node_rank=${RANK} \
+python3 -m torch.distributed.launch --use_env --nproc_per_node=${GPUS_PER_NODE} --nnodes=${WORKER_CNT} --node_rank=${RANK} \
           --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} cn_clip/training/main.py \
           --train-data=${train_data} \
           --val-data=${val_data} \

run_scripts/muge_finetune_vit-b-16_rbt-base.sh

Lines changed: 1 addition & 1 deletion

@@ -57,7 +57,7 @@ text_model=RoBERTa-wwm-ext-base-chinese
 use_augment="--use-augment"
 # use_augment=""

-python3 -m torch.distributed.launch --nproc_per_node=${GPUS_PER_NODE} --nnodes=${WORKER_CNT} --node_rank=${RANK} \
+python3 -m torch.distributed.launch --use_env --nproc_per_node=${GPUS_PER_NODE} --nnodes=${WORKER_CNT} --node_rank=${RANK} \
           --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} cn_clip/training/main.py \
           --train-data=${train_data} \
           --val-data=${val_data} \
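The only change in each launch script is the added `--use_env` flag, which makes `torch.distributed.launch` export `LOCAL_RANK` into each worker's environment instead of appending a `--local_rank` argument. A minimal illustration of that environment-variable handoff (simulating one worker, not the actual launcher):

```shell
# Simulate what the launcher does for one worker: set LOCAL_RANK in the
# environment, then let the training script read it via os.environ.
LOCAL_RANK=2 python3 -c 'import os; print(int(os.environ["LOCAL_RANK"]))'
# prints: 2
```

On recent PyTorch versions, `torchrun` is the successor to `python -m torch.distributed.launch` and always uses the environment-variable convention.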
