11import cv2
2+ import os
23import torch
34import clip
45import numpy as np
@@ -17,16 +18,17 @@ def __init__(self):
1718 self .device = "cuda" if torch .cuda .is_available () else "cpu"
1819 print (f"Using { self .device } ." )
1920 print ("Downloading model will take a while for the first time." )
21+ self .template_target_image = np .zeros ([100 , 100 , 3 ], dtype = np .uint8 )+ 100
2022 self .model , self .preprocess = clip .load ("ViT-B/32" , device = self .device )
2123
22- def search_image (self , target_image_info : dict , source_image_path : str ):
23- top_k = 3 # 最大匹配数量
24- text_alpha = 0.5 # 文本语义scale
24+ def search_image (self , target_image_info : dict , source_image_path , top_k , image_alpha , text_alpha ):
25+ top_k = top_k # 最大匹配数量
26+ image_alpha = image_alpha # 图相关系数
27+ text_alpha = text_alpha # 文本相关系数
2528 roi_list = []
2629 img_text_score = []
27- target_image_path = target_image_info .get ('path ' , '' )
30+ target_image = target_image_info .get ('img ' , self . template_target_image )
2831 target_image_desc = target_image_info .get ('desc' , '' )
29- target_image = cv2 .imread (target_image_path )
3032 source_image = cv2 .imread (source_image_path )
3133 image_infer_result = get_ui_infer (source_image_path )
3234 text = clip .tokenize ([target_image_desc ]).to (self .device )
@@ -48,24 +50,72 @@ def search_image(self, target_image_info: dict, source_image_path: str):
4850 # 图像加文本
4951 for i , source_image_feature in enumerate (source_image_features ):
5052 score = cosine_similar (target_image_features [0 ], source_image_feature )
51- img_text_score .append (score + probs [0 ][i ]* text_alpha )
53+ img_text_score .append (score * image_alpha + probs [0 ][i ]* text_alpha )
5254 score_norm = (img_text_score - np .min (img_text_score )) / (np .max (img_text_score ) - np .min (img_text_score ))
5355 top_k_ids = np .argsort (score_norm )[- top_k :]
5456 return top_k_ids , score_norm , image_infer_result
5557
def get_trace_result(self, target_image_info, source_image_path, top_k=3, image_alpha=1.0, text_alpha=0.6):
    """Search the source image for the target and return the rendered result.

    Runs ``self.search_image`` and forwards the best-matching regions to
    ``img_show`` for drawing.

    Args:
        target_image_info: Dict describing the target; passed through
            unchanged to ``search_image``.
        source_image_path: Path of the image to search in. It is read again
            here with ``cv2.imread`` to serve as the drawing canvas.
        top_k: Maximum number of matches requested from ``search_image``.
        image_alpha: Weight of the image-similarity score.
        text_alpha: Weight of the text-similarity score.

    Returns:
        Whatever ``img_show`` returns -- presumably the annotated image,
        since callers hand it to ``cv2.imwrite`` / ``VideoWriter.write``
        (TODO confirm against ``img_show``).
    """
    hit_ids, all_scores, detections = self.search_image(
        target_image_info, source_image_path, top_k, image_alpha, text_alpha
    )
    # Every match gets class id 0, i.e. the single entry of class_names below.
    labels = np.zeros(len(hit_ids), dtype=int)
    regions = [detections[idx]['elem_det_region'] for idx in hit_ids]
    # Convert to plain Python floats before passing the scores on.
    confidences = [float(all_scores[idx]) for idx in hit_ids]
    canvas = cv2.imread(source_image_path)
    return img_show(canvas, regions, confidences, labels, conf=0.6, class_names=['T'])
5666
57- if __name__ == '__main__' :
def video_target_track(self, video_path, target_image_info, work_path):
    """Trace the target through a video and write the annotated result.

    Each frame is dumped to ``<work_path>/im_temp.png``, passed through
    ``self.get_trace_result`` with ``top_k=1``, and the returned image is
    appended to ``<work_path>/video_out.mp4`` (``mp4v`` codec, 20 fps, sized
    like the first frame of the input).

    Args:
        video_path: Path of the input video.
        target_image_info: Dict describing the target; passed through
            unchanged to ``get_trace_result``.
        work_path: Directory that receives the temporary frame image and
            the output video.

    Raises:
        OSError: If no first frame can be read from ``video_path``.
    """
    video_cap = cv2.VideoCapture(video_path)
    out = None
    try:
        # The first frame is only used to size the writer; it is not traced.
        ret, im = video_cap.read()
        if not ret or im is None:
            raise OSError(f"Cannot read a first frame from video: {video_path}")
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        im_save_path = os.path.join(work_path, 'im_temp.png')
        video_out_path = os.path.join(work_path, 'video_out.mp4')
        out = cv2.VideoWriter(video_out_path, fourcc, 20, (im.shape[1], im.shape[0]))
        i = 0
        while True:
            i = i + 1
            # NOTE(review): this check runs before any frame is read, so no
            # frame is actually skipped -- only the printed counter jumps by
            # two. Kept as-is to preserve the existing output; confirm whether
            # real frame skipping was intended.
            if i % 2 == 0:
                continue
            print(f"video parsing {i} ")
            ret, im = video_cap.read()
            if not ret:
                print("finish.")
                break
            # get_trace_result takes a file path, so round-trip via disk.
            cv2.imwrite(im_save_path, im)
            trace_result = self.get_trace_result(target_image_info, im_save_path, top_k=1)
            out.write(trace_result)
    finally:
        # Release both handles on every exit path, including errors raised
        # while tracing a frame, so the output file is always finalized.
        video_cap.release()
        if out is not None:
            out.release()
89+
90+
def search_target_image():
    """Demo: search one screenshot for the search icon and save the result.

    Loads the target icon, runs ``ImageTrace.get_trace_result`` on the
    source screenshot with image weight 1.0 and text weight 0.6, and writes
    the returned image to ``./capture/local_images/trace_result.png``.
    """
    icon_path = "./capture/local_images/search_icon.png"
    target_image_info = {
        'path': icon_path,
        'desc': "shape of magnifier with blue background",
        # Decoded up front: the tracer reads the target from the 'img' key.
        'img': cv2.imread(icon_path),
    }
    tracer = ImageTrace()
    annotated = tracer.get_trace_result(
        target_image_info,
        "./capture/local_images/05.png",
        image_alpha=1.0,
        text_alpha=0.6,
    )
    cv2.imwrite("./capture/local_images/trace_result.png", annotated)
106+
107+
def trace_target_video():
    """Demo: trace the play-button icon through the sample video.

    Loads the target icon and hands it to ``ImageTrace.video_target_track``
    together with the sample video path; ``./capture/local_images`` is used
    as the working directory.
    """
    icon_path = "./capture/local_images/img_play_icon.png"
    target_image_info = {
        'path': icon_path,
        'desc': "picture with play button",
        'img': cv2.imread(icon_path),
    }
    tracer = ImageTrace()
    tracer.video_target_track(
        "./capture/local_images/video.mp4",
        target_image_info,
        './capture/local_images',
    )
118+
119+
# Script entry point: run the single-image search demo.
if __name__ == "__main__":
    search_target_image()
0 commit comments