1+ '''
2+ Gradio demo (almost the same code as the one used in Huggingface space)
3+ '''
14import os , sys
25import cv2
6+ import time
37import gradio as gr
48import torch
59import numpy as np
@@ -20,6 +24,10 @@ def auto_download_if_needed(weight_path):
2024 if not os .path .exists ("pretrained" ):
2125 os .makedirs ("pretrained" )
2226
27+ if weight_path == "pretrained/4x_APISR_RRDB_GAN_generator.pth" :
28+ os .system ("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.2.0/4x_APISR_RRDB_GAN_generator.pth" )
29+ os .system ("mv 4x_APISR_RRDB_GAN_generator.pth pretrained" )
30+
2331 if weight_path == "pretrained/4x_APISR_GRL_GAN_generator.pth" :
2432 os .system ("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/4x_APISR_GRL_GAN_generator.pth" )
2533 os .system ("mv 4x_APISR_GRL_GAN_generator.pth pretrained" )
@@ -28,6 +36,7 @@ def auto_download_if_needed(weight_path):
2836 os .system ("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/2x_APISR_RRDB_GAN_generator.pth" )
2937 os .system ("mv 2x_APISR_RRDB_GAN_generator.pth pretrained" )
3038
39+
3140
3241
3342def inference (img_path , model_name ):
@@ -41,22 +50,29 @@ def inference(img_path, model_name):
4150 auto_download_if_needed (weight_path )
4251 generator = load_grl (weight_path , scale = 4 ) # Directly use default way now
4352
53+ elif model_name == "4xRRDB" :
54+ weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth"
55+ auto_download_if_needed (weight_path )
56+ generator = load_rrdb (weight_path , scale = 4 ) # Directly use default way now
57+
4458 elif model_name == "2xRRDB" :
4559 weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
4660 auto_download_if_needed (weight_path )
4761 generator = load_rrdb (weight_path , scale = 2 ) # Directly use default way now
4862
4963 else :
50- raise gr .Error (error )
64+ raise gr .Error ("We don't support such Model" )
5165
5266 generator = generator .to (dtype = weight_dtype )
5367
5468
5569 # In default, we will automatically use crop to match 4x size
56- super_resolved_img = super_resolve_img (generator , img_path , output_path = None , weight_dtype = weight_dtype , crop_for_4x = True )
57- save_image (super_resolved_img , "SR_result.png" )
58- outputs = cv2 .imread ("SR_result.png" )
70+ super_resolved_img = super_resolve_img (generator , img_path , output_path = None , weight_dtype = weight_dtype , downsample_threshold = 720 , crop_for_4x = True )
71+ store_name = str (time .time ()) + ".png"
72+ save_image (super_resolved_img , store_name )
73+ outputs = cv2 .imread (store_name )
5974 outputs = cv2 .cvtColor (outputs , cv2 .COLOR_RGB2BGR )
75+ os .remove (store_name )
6076
6177 return outputs
6278
@@ -70,14 +86,18 @@ def inference(img_path, model_name):
7086
7187 MARKDOWN = \
7288 """
73- ## APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024)
74-
89+ ## <p style='text-align: center'> APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024) </p>
90+
7591 [GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598)
76-
77- If APISR is helpful for you, please help star the GitHub Repo. Thanks!
92+ APISR aims at restoring and enhancing low-quality low-resolution **anime** images and video sources with various degradations from real-world scenarios.
93+
94+ ### Note: Due to memory restriction, all images whose short side is over 720 pixel will be downsampled to 720 pixel with the same aspect ratio. E.g., 1920x1080 -> 1280x720
95+ ### Note: Please check [Model Zoo](https://github.com/Kiteretsu77/APISR/blob/main/docs/model_zoo.md) for the description of each weight.
96+
97+ If APISR is helpful, please help star the [GitHub Repo](https://github.com/Kiteretsu77/APISR). Thanks!
7898 """
7999
80- block = gr .Blocks ().queue ()
100+ block = gr .Blocks ().queue (max_size = 10 )
81101 with block :
82102 with gr .Row ():
83103 gr .Markdown (MARKDOWN )
@@ -87,6 +107,7 @@ def inference(img_path, model_name):
87107 model_name = gr .Dropdown (
88108 [
89109 "2xRRDB" ,
110+ "4xRRDB" ,
90111 "4xGRL"
91112 ],
92113 type = "value" ,
@@ -106,7 +127,7 @@ def inference(img_path, model_name):
106127 ["__assets__/lr_inputs/41.png" ],
107128 ["__assets__/lr_inputs/f91.jpg" ],
108129 ["__assets__/lr_inputs/image-00440.png" ],
109- ["__assets__/lr_inputs/image-00164.png " ],
130+ ["__assets__/lr_inputs/image-00164.jpg " ],
110131 ["__assets__/lr_inputs/img_eva.jpeg" ],
111132 ["__assets__/lr_inputs/naruto.jpg" ],
112133 ],
@@ -115,4 +136,4 @@ def inference(img_path, model_name):
115136
116137 run_btn .click (inference , inputs = [input_image , model_name ], outputs = [output_image ])
117138
118- block .launch ()
139+ block .launch ()
0 commit comments