Skip to content

Commit 7a1b87d

Browse files
committed
requirements, simplified torch2onnx
1 parent 3728cee commit 7a1b87d

File tree

4 files changed

+56
-86
lines changed

4 files changed

+56
-86
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
.idea
2+
__pycache__
3+
*.pth
4+
*.onnx

README.md

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,40 @@
1-
# GFPGAN-onnxruntime-demo
2-
This is the onnxruntime inference code for GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior (CVPR 2021). Official code: https://github.com/TencentARC/GFPGAN
1+
# Fork of [xuanandsix/GFPGAN-onnxruntime-demo](https://github.com/xuanandsix/GFPGAN-onnxruntime-demo)
32

4-
## The following issues are addressed:
5-
1、noise = out.new_empty(b, 1, h, w).normal_() in stylegan2_clean_arch.py can‘t be supported in ONNX. I move it out the Model class, like noise = Noise[i], the Noise is a list or others which prestores generated random noise.
3+
Differences between original repository and fork:
64

7-
2、the forward function of Model is very bad, especially stylegan, so many " if else " and class be reused. Like the StyleConv " in "useself.style_convs.append StyleConv ...". So I rewrite and make it in single forward.
5+
* Compatibility with PyTorch >=2.4. (🔥)
6+
* Original pretrained models and converted ONNX models from GitHub [releases page](https://github.com/clibdev/GFPGAN-onnxruntime-demo/releases). (🔥)
7+
* Installation with [requirements.txt](requirements.txt) file.
8+
* Simplified [torch2onnx.py](torch2onnx.py) file.
9+
* The following warnings has been fixed:
10+
* FutureWarning: You are using 'torch.load' with 'weights_only=False'.
811

9-
## convert torch to onnx.
10-
```
11-
wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth
12+
# Installation
1213

13-
python torch2onnx.py --src_model_path ./GFPGANv1.3.pth --dst_model_path ./GFPGANv1.3.onnx --img_size 512
14+
```shell
15+
pip install -r requirements.txt
1416
```
1517

16-
## run onnx demo.
17-
```
18-
python demo_onnx.py --model_path GFPGANv1.3.onnx --image_path ./cropped_faces/Adele_crop.png --save_path Adele_v3.jpg
19-
```
18+
# Pretrained models
2019

21-
| input | output|
22-
| :-: |:-:|
23-
|<img src="https://github.com/xuanandsix/GFPGAN-onnxruntime-demo/raw/main/cropped_faces/Justin_Timberlake_crop.png" height="80%" width="80%">|<img src="https://github.com/xuanandsix/GFPGAN-onnxruntime-demo/raw/main/imgs/Justin_Timberlake_v2.jpg" height="80%" width="80%">|
24-
|<img src="https://github.com/xuanandsix/GFPGAN-onnxruntime-demo/raw/main/cropped_faces/Julia_Roberts_crop.png" height="80%" width="80%">|<img src="https://github.com/xuanandsix/GFPGAN-onnxruntime-demo/raw/main/imgs/Julia_Roberts_v2.jpg" height="80%" width="80%">|
25-
|<img src="https://github.com/xuanandsix/GFPGAN-onnxruntime-demo/raw/main/cropped_faces/Paris_Hilton_crop.png" height="80%" width="80%">|<img src="https://github.com/xuanandsix/GFPGAN-onnxruntime-demo/raw/main/imgs/Paris_Hilton_v2.jpg" height="80%" width="80%">|
20+
| Name | Link |
21+
|------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
22+
| GFPGANv1.2 | [PyTorch](https://github.com/clibdev/GFPGAN-onnxruntime-demo/releases/latest/download/GFPGANv1.2.pth), [ONNX](https://github.com/clibdev/GFPGAN-onnxruntime-demo/releases/latest/download/GFPGANv1.2.onnx) |
23+
| GFPGANv1.3 | [PyTorch](https://github.com/clibdev/GFPGAN-onnxruntime-demo/releases/latest/download/GFPGANv1.3.pth), [ONNX](https://github.com/clibdev/GFPGAN-onnxruntime-demo/releases/latest/download/GFPGANv1.3.onnx) |
24+
| GFPGANv1.4 | [PyTorch](https://github.com/clibdev/GFPGAN-onnxruntime-demo/releases/latest/download/GFPGANv1.4.pth), [ONNX](https://github.com/clibdev/GFPGAN-onnxruntime-demo/releases/latest/download/GFPGANv1.4.onnx) |
2625

26+
# Export to ONNX format
2727

28+
```shell
29+
python torch2onnx.py --src_model_path GFPGANv1.2.pth --dst_model_path GFPGANv1.2.onnx
30+
python torch2onnx.py --src_model_path GFPGANv1.3.pth --dst_model_path GFPGANv1.3.onnx
31+
python torch2onnx.py --src_model_path GFPGANv1.4.pth --dst_model_path GFPGANv1.4.onnx
32+
```
33+
34+
# Inference
35+
36+
```shell
37+
python demo_onnx.py --model_path GFPGANv1.2.onnx --image_path cropped_faces/Justin_Timberlake_crop.png --save_path output1.2.png
38+
python demo_onnx.py --model_path GFPGANv1.3.onnx --image_path cropped_faces/Justin_Timberlake_crop.png --save_path output1.3.png
39+
python demo_onnx.py --model_path GFPGANv1.3.onnx --image_path cropped_faces/Justin_Timberlake_crop.png --save_path output1.4.png
40+
```

requirements.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
torch>=2.4.0
2+
torchvision>=0.19.0
3+
opencv-python>=4.10.0
4+
onnx>=1.16.0
5+
onnxruntime>=1.19.0

torch2onnx.py

Lines changed: 16 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,22 @@
1-
# -*- coding: utf-8 -*-
2-
3-
#import cv2
4-
import numpy as np
5-
import time
61
import torch
7-
import pdb
82
from collections import OrderedDict
9-
10-
import sys
11-
sys.path.append('.')
12-
sys.path.append('./lib')
13-
import torch.nn as nn
14-
from torch.autograd import Variable
15-
import onnxruntime
16-
import timeit
17-
183
import argparse
194
from GFPGANReconsitution import GFPGAN
205

216
parser = argparse.ArgumentParser("ONNX converter")
227
parser.add_argument('--src_model_path', type=str, default=None, help='src model path')
238
parser.add_argument('--dst_model_path', type=str, default=None, help='dst model path')
24-
parser.add_argument('--img_size', type=int, default=None, help='img size')
259
args = parser.parse_args()
26-
27-
#device = torch.device('cuda')
10+
11+
# device = torch.device('cuda')
2812
model_path = args.src_model_path
2913
onnx_model_path = args.dst_model_path
30-
img_size = args.img_size
3114

32-
model = GFPGAN()#.cuda()
15+
model = GFPGAN() # .cuda()
3316

34-
x = torch.rand(1, 3, 512, 512)#.cuda()
17+
x = torch.rand(1, 3, 512, 512) # .cuda()
3518

36-
state_dict = torch.load(model_path)['params_ema']
19+
state_dict = torch.load(model_path, weights_only=True)['params_ema']
3720
new_state_dict = OrderedDict()
3821
for k, v in state_dict.items():
3922
# stylegan_decoderdotto_rgbsdot1dotmodulated_convdotbias
@@ -45,52 +28,17 @@
4528
new_state_dict[k] = v
4629
else:
4730
new_state_dict[k] = v
48-
31+
4932
model.load_state_dict(new_state_dict, strict=False)
5033
model.eval()
5134

52-
torch.onnx.export(model, x, onnx_model_path,
53-
export_params=True, opset_version=11, do_constant_folding=True,
54-
input_names = ['input'],output_names = [])
55-
56-
57-
####
58-
try:
59-
original_model = onnx.load(onnx_model_path)
60-
passes = ['fuse_bn_into_conv']
61-
optimized_model = optimizer.optimize(original_model, passes)
62-
onnx.save(optimized_model, onnx_model_path)
63-
except:
64-
print('skip optimize.')
65-
66-
####
67-
ort_session = onnxruntime.InferenceSession(onnx_model_path)
68-
for var in ort_session.get_inputs():
69-
print(var.name)
70-
for var in ort_session.get_outputs():
71-
print(var.name)
72-
_,_,input_h,input_w = ort_session.get_inputs()[0].shape
73-
t = timeit.default_timer()
74-
75-
img = np.zeros((input_h,input_w,3))
76-
77-
img = (np.transpose(np.float32(img[:,:,:,np.newaxis]), (3,2,0,1)) )#*self.scale
78-
79-
img = np.ascontiguousarray(img)
80-
#
81-
ort_inputs = {ort_session.get_inputs()[0].name: img}
82-
ort_outs = ort_session.run(None, ort_inputs)
83-
84-
print('onnxruntime infer time:', timeit.default_timer()-t)
85-
print(ort_outs[0].shape)
86-
87-
# python torch2onnx.py --src_model_path ./experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth --dst_model_path ./GFPGAN.onnx --img_size 512
88-
89-
# 新版本
90-
91-
92-
# wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth
93-
94-
# python torch2onnx.py --src_model_path ./GFPGANv1.4.pth --dst_model_path ./GFPGANv1.4.onnx --img_size 512
95-
96-
# python torch2onnx.py --src_model_path ./GFPGANCleanv1-NoCE-C2.pth --dst_model_path ./GFPGANv1.2.onnx --img_size 512
35+
torch.onnx.export(
36+
model,
37+
x,
38+
onnx_model_path,
39+
export_params=True,
40+
opset_version=11,
41+
do_constant_folding=True,
42+
input_names=['input'],
43+
output_names=[]
44+
)

0 commit comments

Comments
 (0)