Skip to content

Commit 56561d1

Browse files
committed
feat:video target trace
1 parent 080081d commit 56561d1

File tree

1 file changed

+65
-15
lines changed

1 file changed

+65
-15
lines changed

service/image_trace.py

Lines changed: 65 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import cv2
2+
import os
23
import torch
34
import clip
45
import numpy as np
@@ -17,16 +18,17 @@ def __init__(self):
1718
self.device = "cuda" if torch.cuda.is_available() else "cpu"
1819
print(f"Using {self.device}.")
1920
print("Downloading model will take a while for the first time.")
21+
self.template_target_image = np.zeros([100, 100, 3], dtype=np.uint8)+100
2022
self.model, self.preprocess = clip.load("ViT-B/32", device=self.device)
2123

22-
def search_image(self, target_image_info: dict, source_image_path: str):
23-
top_k = 3 # 最大匹配数量
24-
text_alpha = 0.5 # 文本语义scale
24+
def search_image(self, target_image_info: dict, source_image_path, top_k, image_alpha, text_alpha):
25+
top_k = top_k # 最大匹配数量
26+
image_alpha = image_alpha # 图相关系数
27+
text_alpha = text_alpha # 文本相关系数
2528
roi_list = []
2629
img_text_score = []
27-
target_image_path = target_image_info.get('path', '')
30+
target_image = target_image_info.get('img', self.template_target_image)
2831
target_image_desc = target_image_info.get('desc', '')
29-
target_image = cv2.imread(target_image_path)
3032
source_image = cv2.imread(source_image_path)
3133
image_infer_result = get_ui_infer(source_image_path)
3234
text = clip.tokenize([target_image_desc]).to(self.device)
@@ -48,24 +50,72 @@ def search_image(self, target_image_info: dict, source_image_path: str):
4850
# 图像加文本
4951
for i, source_image_feature in enumerate(source_image_features):
5052
score = cosine_similar(target_image_features[0], source_image_feature)
51-
img_text_score.append(score + probs[0][i]*text_alpha)
53+
img_text_score.append(score*image_alpha + probs[0][i]*text_alpha)
5254
score_norm = (img_text_score - np.min(img_text_score)) / (np.max(img_text_score) - np.min(img_text_score))
5355
top_k_ids = np.argsort(score_norm)[-top_k:]
5456
return top_k_ids, score_norm, image_infer_result
5557

58+
def get_trace_result(self, target_image_info, source_image_path, top_k=3, image_alpha=1.0, text_alpha=0.6):
59+
top_k_ids, scores, infer_result = self.search_image(target_image_info, source_image_path,
60+
top_k, image_alpha, text_alpha)
61+
cls_ids = np.zeros(len(top_k_ids), dtype=int)
62+
boxes = [infer_result[i]['elem_det_region'] for i in top_k_ids]
63+
scores = [float(scores[i]) for i in top_k_ids]
64+
image_show = img_show(cv2.imread(source_image_path), boxes, scores, cls_ids, conf=0.6, class_names=['T'])
65+
return image_show
5666

57-
if __name__ == '__main__':
67+
def video_target_track(self, video_path, target_image_info, work_path):
68+
video_cap = cv2.VideoCapture(video_path)
69+
_, im = video_cap.read()
70+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
71+
im_save_path = os.path.join(work_path, 'im_temp.png')
72+
video_out_path = os.path.join(work_path, 'video_out.mp4')
73+
out = cv2.VideoWriter(video_out_path, fourcc, 20, (im.shape[1], im.shape[0]))
74+
i = 0
75+
while 1:
76+
i = i + 1
77+
if i % 2 == 0:
78+
continue
79+
print(f"video parsing {i}")
80+
ret, im = video_cap.read()
81+
if ret:
82+
cv2.imwrite(im_save_path, im)
83+
trace_result = self.get_trace_result(target_image_info, im_save_path, top_k=1)
84+
out.write(trace_result)
85+
else:
86+
print("finish.")
87+
out.release()
88+
break
89+
90+
91+
def search_target_image():
92+
# search a target image
93+
image_alpha = 1.0
94+
text_alpha = 0.6
5895
target_image_info = {
5996
'path': "./capture/local_images/search_icon.png",
6097
'desc': "shape of magnifier with blue background"
6198
}
62-
source_image_path = "./capture/image_1.png"
63-
image_trace = ImageTrace()
64-
top_k_ids, scores, infer_result = image_trace.search_image(target_image_info, source_image_path)
65-
# show result
99+
source_image_path = "./capture/local_images/05.png"
66100
trace_result_path = "./capture/local_images/trace_result.png"
67-
cls_ids = np.zeros(len(top_k_ids), dtype=int)
68-
boxes = [infer_result[i]['elem_det_region'] for i in top_k_ids]
69-
scores = [float(scores[i]) for i in top_k_ids]
70-
image_trace_show = img_show(cv2.imread(source_image_path), boxes, scores, cls_ids, conf=0.5, class_names=['T'])
101+
target_image_info['img'] = cv2.imread(target_image_info['path'])
102+
image_trace = ImageTrace()
103+
image_trace_show = image_trace.get_trace_result(target_image_info, source_image_path,
104+
image_alpha=image_alpha, text_alpha=text_alpha)
71105
cv2.imwrite(trace_result_path, image_trace_show)
106+
107+
108+
def trace_target_video():
109+
target_image_info = {
110+
'path': "./capture/local_images/img_play_icon.png",
111+
'desc': "picture with play button"
112+
}
113+
target_image_info['img'] = cv2.imread(target_image_info['path'])
114+
video_path = "./capture/local_images/video.mp4"
115+
work_path = './capture/local_images'
116+
image_trace = ImageTrace()
117+
image_trace.video_target_track(video_path, target_image_info, work_path)
118+
119+
120+
if __name__ == '__main__':
121+
search_target_image()

0 commit comments

Comments
 (0)