1-
2- """This file is modifed from synthesize.py. The goal is to return a generator which output an image in range [0., 1.]"""
3-
4- import os
5- import argparse
6- import subprocess
7- from tqdm import tqdm
8- import numpy as np
9-
10- import torch
11- from torchvision .utils import save_image
12-
13- from .models import MODEL_ZOO
14- from .models import build_generator , build_discriminator
15- from .utils .misc import bool_parser
16- from .utils .visualizer import HtmlPageVisualizer
17-
18- def postprocess (images ):
19- """change the range from [-1, 1] to [0., 1.]"""
20- images = torch .clamp ((images + 1. ) / 2. , 0. , 1. )
21- return images
22-
23- def get_genforce (model_name , device , checkpoint_dir , use_discri = True , use_w_space = True , use_z_plus_space = False , repeat_w = True ):
24-
25- trunc_psi = 0.7
26- trunc_layers = 8
27-
28- if model_name not in MODEL_ZOO :
29- raise RuntimeError (f'model name `{ model_name } ` is not in model zoo' )
30- model_config = MODEL_ZOO [model_name ].copy ()
31- url = model_config .pop ('url' )
32-
33- print (f'Building generator for model `{ model_name } `' )
34- if model_name .startswith ('stylegan' ):
35- generator = build_generator (** model_config , repeat_w = repeat_w )
36- else :
37- generator = build_generator (** model_config )
38- synthesis_kwargs = dict (trunc_psi = trunc_psi ,
39- trunc_layers = trunc_layers )
40-
41- # Build discriminator
42- if use_discri :
43- print (f'Building discriminator for model `{ model_name } ` ...' )
44- discriminator = build_discriminator (** model_config )
45- else :
46- discriminator = None
47-
48- # load checkpoints
49- os .makedirs (os .path .join (checkpoint_dir , 'genforce' ), exist_ok = True )
50- ckpt_path = os .path .join (checkpoint_dir , 'genforce' , f'{ model_name } .pth' )
51-
52- if not os .path .exists (ckpt_path ):
53- print (f'Download checkpoint { model_name } from { url } ...' )
54- subprocess .call (['wget' , '--quiet' , '-O' , ckpt_path , url ])
55-
56- checkpoint = torch .load (ckpt_path )
57-
58- if 'generator_smooth' in checkpoint :
59- generator .load_state_dict (checkpoint ['generator_smooth' ])
60- else :
61- generator .load_state_dict (checkpoint ['generator' ])
62- generator = generator .to (device )
63- generator .eval ()
64- if use_discri :
65- discriminator .load_state_dict (checkpoint ['discriminator' ])
66- discriminator = discriminator .to (device )
67- discriminator .eval ()
68- print ('Finish loading checkpoint.' )
69-
70- def fake_generator (code ):
71- # Sample and synthesize.
72- # print(f'Synthesizing {args.num} samples ...')
73- # code = torch.randn(args.batch_size, generator.z_space_dim).cuda()
74- if use_z_plus_space :
75- code = generator .mapping (code )['w' ]
76- code = code .view (- 1 , generator .num_layers , generator .w_space_dim )
77- images = generator (code , ** synthesis_kwargs , use_w_space = use_w_space )['image' ]
78- images = postprocess (images )
79- # save_image(images, os.path.join(work_dir, 'tmp.png'), nrow=5)
80- # print(f'Finish synthesizing {args.num} samples.')
81- return images
82-
83- return Fake_G (generator , fake_generator ), discriminator
84-
85- class Fake_G :
86-
87- def __init__ (self , G , g_function ):
88- self .G = G
89- self .g_function = g_function
90-
91- def __call__ (self , code ):
92- # print(f'code.shape {code.shape}')
93- return self .g_function (code )
94-
95- def mapping (self , code , label = None ):
96- return self .G .mapping (code , label = None )
97-
98- def zero_grad (self ):
1+
2+ """This file is modifed from synthesize.py. The goal is to return a generator which output an image in range [0., 1.]"""
3+
4+
5+ import os
6+ import argparse
7+ import subprocess
8+ from tqdm import tqdm
9+ import numpy as np
10+
11+ import torch
12+ from torchvision .utils import save_image
13+
14+ from .models import MODEL_ZOO
15+ from .models import build_generator , build_discriminator
16+ from .utils .misc import bool_parser
17+ from .utils .visualizer import HtmlPageVisualizer
18+
19+ def postprocess (images ):
20+ """change the range from [-1, 1] to [0., 1.]"""
21+ images = torch .clamp ((images + 1. ) / 2. , 0. , 1. )
22+ return images
23+
24+ def get_genforce (model_name , device , checkpoint_dir , use_discri = True , use_w_space = True , use_z_plus_space = False , repeat_w = True ):
25+
26+ trunc_psi = 0.7
27+ trunc_layers = 8
28+
29+ if model_name not in MODEL_ZOO :
30+ raise RuntimeError (f'model name `{ model_name } ` is not in model zoo' )
31+ model_config = MODEL_ZOO [model_name ].copy ()
32+ url = model_config .pop ('url' )
33+
34+ print (f'Building generator for model `{ model_name } `' )
35+ if model_name .startswith ('stylegan' ):
36+ generator = build_generator (** model_config , repeat_w = repeat_w )
37+ else :
38+ generator = build_generator (** model_config )
39+ synthesis_kwargs = dict (trunc_psi = trunc_psi ,
40+ trunc_layers = trunc_layers )
41+
42+ # Build discriminator
43+ if use_discri :
44+ print (f'Building discriminator for model `{ model_name } ` ...' )
45+ discriminator = build_discriminator (** model_config )
46+ else :
47+ discriminator = None
48+
49+ # load checkpoints
50+ os .makedirs (os .path .join (checkpoint_dir , 'genforce' ), exist_ok = True )
51+ ckpt_path = os .path .join (checkpoint_dir , 'genforce' , f'{ model_name } .pth' )
52+
53+ if not os .path .exists (ckpt_path ):
54+ print (f'Download checkpoint { model_name } from { url } ...' )
55+ subprocess .call (['wget' , '--quiet' , '-O' , ckpt_path , url ])
56+
57+ checkpoint = torch .load (ckpt_path )
58+
59+ if 'generator_smooth' in checkpoint :
60+ generator .load_state_dict (checkpoint ['generator_smooth' ])
61+ else :
62+ generator .load_state_dict (checkpoint ['generator' ])
63+ generator = generator .to (device )
64+ generator .eval ()
65+ if use_discri :
66+ discriminator .load_state_dict (checkpoint ['discriminator' ])
67+ discriminator = discriminator .to (device )
68+ discriminator .eval ()
69+ print ('Finish loading checkpoint.' )
70+
71+ def fake_generator (code ):
72+ # Sample and synthesize.
73+ # print(f'Synthesizing {args.num} samples ...')
74+ # code = torch.randn(args.batch_size, generator.z_space_dim).cuda()
75+ if use_z_plus_space :
76+ code = generator .mapping (code )['w' ]
77+ code = code .view (- 1 , generator .num_layers , generator .w_space_dim )
78+ images = generator (code , ** synthesis_kwargs , use_w_space = use_w_space )['image' ]
79+ images = postprocess (images )
80+ # save_image(images, os.path.join(work_dir, 'tmp.png'), nrow=5)
81+ # print(f'Finish synthesizing {args.num} samples.')
82+ return images
83+
84+ return Fake_G (generator , fake_generator ), discriminator
85+
86+ class Fake_G :
87+
88+ def __init__ (self , G , g_function ):
89+ self .G = G
90+ self .g_function = g_function
91+
92+ def __call__ (self , code ):
93+ # print(f'code.shape {code.shape}')
94+ return self .g_function (code )
95+
96+ def mapping (self , code , label = None ):
97+ return self .G .mapping (code , label = None )
98+
99+ def zero_grad (self ):
99100 self .G .zero_grad ()
0 commit comments