
Commit f2001f8

feat: adapt to transformers==4.10.2 and network downloads
1 parent b1fe6c7 commit f2001f8

File tree

4 files changed: +18 −10 lines changed


code/chapter-8/07_image_captioning/clip_cap_base/01_train.py

Lines changed: 3 additions & 0 deletions
@@ -17,6 +17,7 @@
 import tqdm
 from torch.utils.data import DataLoader
 from transformers import AdamW, get_linear_schedule_with_warmup
+# from transformers import AdamW, WarmupLinearSchedule
 from my_models.models import *
 from my_datasets.cocodataset import *

@@ -36,6 +37,8 @@ def train(dataset: ClipCocoDataset, model: ClipCaptionModel, args,
     scheduler = get_linear_schedule_with_warmup(
         optimizer, num_warmup_steps=warmup_steps, num_training_steps=epochs * len(train_dataloader)
     )
+    # scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=epochs * len(train_dataloader))
+
     # save_config(args)
     for epoch in range(epochs):
         print(f">>> Training epoch {epoch}")

code/chapter-8/07_image_captioning/clip_cap_base/02_inference.py

Lines changed: 13 additions & 10 deletions
@@ -6,18 +6,26 @@
 @brief : model inference

 """
-
+import os
+# os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
+# os.environ['HF_ENDPOINT'] = "https://ai.gitee.com/huggingface"
+import os
+os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
+os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'
 import sys
 from pathlib import Path
+
 FILE = Path(__file__).resolve()
 ROOT = FILE.parents[0]  # project root directory
 if str(ROOT) not in sys.path:
     sys.path.append(str(ROOT))  # add ROOT to PATH

 # debug: on Windows this raises: OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.
 import platform
+
 if platform.system() == 'Windows':
     import os
+
     os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

 import clip
@@ -55,7 +63,7 @@ def __init__(self, path_ckpt):
         model = ClipCaptionPrefix(args.prefix_length, clip_length=args.prefix_length, prefix_size=512,
                                   num_layers=args.num_layers, mapping_type=args.mapping_type)

-        model.load_state_dict(torch.load(path_ckpt, map_location=torch.device("cpu")))
+        model.load_state_dict(torch.load(path_ckpt, map_location=torch.device("cpu")), strict=False)
         model = model.eval()
         model = model.to(self.device)
         self.model = model
@@ -78,7 +86,8 @@ def main():
     # download from: https://pan.baidu.com/s/1CuTDtCeT2-nIvRG7N4iKtw (access code: mqri)
     ckpt_path = r'coco_prefix-009-2023-0411.pt'
     path_img = r'G:\deep_learning_data\coco_2014\images\val2014'
-    out_dir = './inference_output'
+    # path_img = r'G:\deep_learning_data\coco_2017\images\train2017\train2017'
+    out_dir = './inference_output2'

     # collect image paths
     img_paths = []
@@ -94,7 +103,7 @@ def main():
     for idx, path_img in tqdm(enumerate(img_paths)):
         caps, pil_image = predictor.predict(path_img, False)
         img_bgr = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
-        cv2.putText(img_bgr, caps, (0, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 1)
+        cv2.putText(img_bgr, caps, (0, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 0, 255), 2)

         # save result
         path_out = os.path.join(out_dir, os.path.basename(path_img))
@@ -105,9 +114,3 @@ def main():

 if __name__ == '__main__':
     main()
-
-
-
-
-
-
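
Two of the inference changes deserve a note. The environment variables only take effect if they are set before transformers/huggingface_hub is imported: HF_ENDPOINT (commented out above) points the hub client at a mirror where the installed version honors it, while HTTP_PROXY/HTTPS_PROXY route downloads through a local proxy; the mirror URLs and port 7890 are the values from the diff, not requirements. Separately, strict=False lets load_state_dict tolerate key mismatches between checkpoint and model; it skips missing or unexpected keys rather than fixing them. A minimal sketch under those assumptions:

import os

# set proxy/mirror BEFORE importing transformers, otherwise the hub client
# may already have read its configuration
# os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'   # mirror option (version-dependent)
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'      # proxy option (example port)
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'

from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained('gpt2')         # download goes through the proxy/mirror

# strict=False ignores key mismatches; inspect the returned lists when debugging
state = model.state_dict()                              # stand-in for torch.load(path_ckpt, map_location="cpu")
missing, unexpected = model.load_state_dict(state, strict=False)
print(missing, unexpected)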

code/chapter-8/07_image_captioning/clip_cap_base/my_models/models.py

Lines changed: 1 addition & 0 deletions
@@ -175,6 +175,7 @@ def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix
         super(ClipCaptionModel, self).__init__()
         self.prefix_length = prefix_length
         self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
+        # self.gpt = GPT2LMHeadModel.from_pretrained('gpt2', force_download=True)
         self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
         if mapping_type == MappingType.MLP:
             self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
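
The commented-out line documents a fallback for flaky downloads: force_download=True is a standard from_pretrained argument that re-fetches the weights instead of trusting a possibly incomplete local cache. A short sketch of the surrounding logic (the printed value is GPT-2 small's published embedding width):

from transformers import GPT2LMHeadModel

gpt = GPT2LMHeadModel.from_pretrained('gpt2')          # normal, cached load
# gpt = GPT2LMHeadModel.from_pretrained('gpt2', force_download=True)  # bypass a broken cache

# the prefix-mapping network is sized from the GPT-2 token embedding width
gpt_embedding_size = gpt.transformer.wte.weight.shape[1]
print(gpt_embedding_size)                              # 768 for GPT-2 small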

code/chapter-8/07_image_captioning/clip_cap_base/my_utils/utils.py

Lines changed: 1 addition & 0 deletions
@@ -122,6 +122,7 @@ def generate2(

                 outputs = model.gpt(inputs_embeds=generated)
                 logits = outputs.logits
+                # logits = outputs[0]
                 logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                 cumulative_probs = torch.cumsum(
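
The outputs.logits access relies on the ModelOutput objects that transformers 4.x returns by default; on much older versions the forward call returned a plain tuple, which is what the commented-out outputs[0] fallback refers to. The lines that follow in the diff are the start of nucleus (top-p) filtering; a minimal self-contained sketch of that step, with fake logits and an illustrative top_p rather than the repository's values:

import torch

logits = torch.randn(1, 7, 50257)          # pretend GPT-2 output: (batch, seq, vocab)
temperature, top_p = 1.0, 0.8              # illustrative values

logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

# drop tokens outside the nucleus, shifting the mask so the top token always survives
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = False
sorted_logits[sorted_indices_to_remove] = float('-inf')

probs = torch.softmax(sorted_logits, dim=-1)                         # renormalise over survivors
next_token = sorted_indices.gather(-1, torch.multinomial(probs, 1))  # sample inside the nucleus
print(next_token.shape)                                              # torch.Size([1, 1])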
