@@ -368,23 +368,26 @@ def list_generator_layers(self):
368368 name = list (self .network .keys ())[list (self .network .values ()).index (l )]
369369 yield (name , l )
370370
371+ def get_filename (self ):
372+ filename = 'ne%ix-%s-%s-%s.pkl.bz2' % (args .zoom , args .type , args .model , __version__ )
373+ return os .path .join (os .path .dirname (__file__ ), filename )
374+
371375 def save_generator (self ):
372376 def cast (p ): return p .get_value ().astype (np .float16 )
373377 params = {k : [cast (p ) for p in l .get_params ()] for (k , l ) in self .list_generator_layers ()}
374378 config = {k : getattr (args , k ) for k in ['generator_blocks' , 'generator_residual' , 'generator_filters' ] + \
375379 ['generator_upscale' , 'generator_downscale' ]}
376- filename = 'ne%ix-%s-%s-%s.pkl.bz2' % ( args . zoom , args . type , args . model , __version__ )
377- pickle .dump ((config , params ), bz2 .open (filename , 'wb' ))
378- print (' - Saved model as `{}` after training.' .format (filename ))
380+
381+ pickle .dump ((config , params ), bz2 .open (self . get_filename () , 'wb' ))
382+ print (' - Saved model as `{}` after training.' .format (self . get_filename () ))
379383
380384 def load_model (self ):
381- filename = 'ne%ix-%s-%s-%s.pkl.bz2' % (args .zoom , args .type , args .model , __version__ )
382- if not os .path .exists (filename ):
385+ if not os .path .exists (self .get_filename ()):
383386 if args .train : return {}, {}
384387 error ("Model file with pre-trained convolution layers not found. Download it here..." ,
385- "https://github.com/alexjc/neural-enhance/releases/download/v%s/%s" % (__version__ , filename ))
386- print (' - Loaded file `{}` with trained model.' .format (filename ))
387- return pickle .load (bz2 .open (filename , 'rb' ))
388+ "https://github.com/alexjc/neural-enhance/releases/download/v%s/%s" % (__version__ , self . get_filename () ))
389+ print (' - Loaded file `{}` with trained model.' .format (self . get_filename () ))
390+ return pickle .load (bz2 .open (self . get_filename () , 'rb' ))
388391
389392 def load_generator (self , params ):
390393 if len (params ) == 0 : return
0 commit comments