1+ #!/usr/bin/env python
2+ # -*- coding: utf-8 -*-
3+ import os
4+ import sys
5+ import argparse
6+ import cv2
7+ import numpy as np
8+ import timeit
9+ import onnxruntime
10+
11+ class GFPGANFaceAugment :
12+ def __init__ (self , model_path , use_gpu = False ):
13+ self .ort_session = onnxruntime .InferenceSession (model_path )
14+ self .net_input_name = self .ort_session .get_inputs ()[0 ].name
15+ _ ,self .net_input_channels ,self .net_input_height ,self .net_input_width = self .ort_session .get_inputs ()[0 ].shape
16+ self .net_output_count = len (self .ort_session .get_outputs ())
17+ self .face_size = 512
18+ self .face_template = np .array ([[192 , 240 ], [319 , 240 ], [257 , 371 ]]) * (self .face_size / 512.0 )
19+ self .upscale_factor = 2
20+ self .affine = False
21+ self .affine_matrix = None
22+ def pre_process (self , img ):
23+ img = cv2 .resize (img , (int (img .shape [1 ] / 2 ), int (img .shape [0 ] / 2 )))
24+ img = cv2 .resize (img , (self .face_size , self .face_size ))
25+ img = img / 255.0
26+ img = img .astype ('float32' )
27+ img = cv2 .cvtColor (img , cv2 .COLOR_BGR2RGB )
28+ img [:,:,0 ] = (img [:,:,0 ]- 0.5 )/ 0.5
29+ img [:,:,1 ] = (img [:,:,1 ]- 0.5 )/ 0.5
30+ img [:,:,2 ] = (img [:,:,2 ]- 0.5 )/ 0.5
31+ img = np .float32 (img [np .newaxis ,:,:,:])
32+ img = img .transpose (0 , 3 , 1 , 2 )
33+ return img
34+ def post_process (self , output , height , width ):
35+ output = output .clip (- 1 ,1 )
36+ output = (output + 1 ) / 2
37+ output = output .transpose (1 , 2 , 0 )
38+ output = cv2 .cvtColor (output , cv2 .COLOR_RGB2BGR )
39+ output = (output * 255.0 ).round ()
40+ if self .affine :
41+ inverse_affine = cv2 .invertAffineTransform (self .affine_matrix )
42+ inverse_affine *= self .upscale_factor
43+ if self .upscale_factor > 1 :
44+ extra_offset = 0.5 * self .upscale_factor
45+ else :
46+ extra_offset = 0
47+ inverse_affine [:, 2 ] += extra_offset
48+ inv_restored = cv2 .warpAffine (output , inverse_affine , (width , height ))
49+ mask = np .ones ((self .face_size , self .face_size ), dtype = np .float32 )
50+ inv_mask = cv2 .warpAffine (mask , inverse_affine , (width , height ))
51+ inv_mask_erosion = cv2 .erode (
52+ inv_mask , np .ones ((int (2 * self .upscale_factor ), int (2 * self .upscale_factor )), np .uint8 ))
53+ pasted_face = inv_mask_erosion [:, :, None ] * inv_restored
54+ total_face_area = np .sum (inv_mask_erosion )
55+ # compute the fusion edge based on the area of face
56+ w_edge = int (total_face_area ** 0.5 ) // 20
57+ erosion_radius = w_edge * 2
58+ inv_mask_center = cv2 .erode (inv_mask_erosion , np .ones ((erosion_radius , erosion_radius ), np .uint8 ))
59+ blur_size = w_edge * 2
60+ inv_soft_mask = cv2 .GaussianBlur (inv_mask_center , (blur_size + 1 , blur_size + 1 ), 0 )
61+ inv_soft_mask = inv_soft_mask [:, :, None ]
62+ output = pasted_face
63+ else :
64+ inv_soft_mask = np .ones ((height , width , 1 ), dtype = np .float32 )
65+ output = cv2 .resize (output , (width , height ))
66+ return output , inv_soft_mask
67+
68+ def forward (self , img ):
69+ height , width = img .shape [0 ], img .shape [1 ]
70+ img = self .pre_process (img )
71+ t = timeit .default_timer ()
72+ ort_inputs = {self .ort_session .get_inputs ()[0 ].name : img }
73+ ort_outs = self .ort_session .run (None , ort_inputs )
74+ output = ort_outs [0 ][0 ]
75+ output , inv_soft_mask = self .post_process (output , height , width )
76+ print ('infer time:' ,timeit .default_timer ()- t )
77+ output = output .astype (np .uint8 )
78+ return output , inv_soft_mask
79+
80+ if __name__ == "__main__" :
81+ parser = argparse .ArgumentParser ("onnxruntime demo" )
82+ parser .add_argument ('--model_path' , type = str , default = None , help = 'model path' )
83+ parser .add_argument ('--image_path' , type = str , default = None , help = 'input image path' )
84+ parser .add_argument ('--save_path' , type = str , default = "output.jpg" , help = 'output image path' )
85+ args = parser .parse_args ()
86+
87+ faceaugment = GFPGANFaceAugment (model_path = args .model_path )
88+ image = cv2 .imread (args .image_path , 1 )
89+ output , _ = faceaugment .forward (image )
90+ cv2 .imwrite (args .save_path , output )
91+
92+ # python demo_onnx.py --model_path GFPGANv1.4.onnx --image_path ./cropped_faces/Adele_crop.png
93+
94+
95+ # python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Adele_crop.png --save_path Adele_v2.jpg
96+ # python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Julia_Roberts_crop.png --save_path Julia_Roberts_v2.jpg
97+ # python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Justin_Timberlake_crop.png --save_path Justin_Timberlake_v2.jpg
98+ # python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path ./cropped_faces/Paris_Hilton_crop.png --save_path Paris_Hilton_v2.jpg
0 commit comments