Skip to content
This repository was archived by the owner on Jan 2, 2021. It is now read-only.

Commit c03a0b7

Browse files
committed
Fix for trained models loaded from script directory. Closes #67.
1 parent a69d2ae commit c03a0b7

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

enhance.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -368,23 +368,26 @@ def list_generator_layers(self):
368368
name = list(self.network.keys())[list(self.network.values()).index(l)]
369369
yield (name, l)
370370

371+
def get_filename(self):
372+
filename = 'ne%ix-%s-%s-%s.pkl.bz2' % (args.zoom, args.type, args.model, __version__)
373+
return os.path.join(os.path.dirname(__file__), filename)
374+
371375
def save_generator(self):
372376
def cast(p): return p.get_value().astype(np.float16)
373377
params = {k: [cast(p) for p in l.get_params()] for (k, l) in self.list_generator_layers()}
374378
config = {k: getattr(args, k) for k in ['generator_blocks', 'generator_residual', 'generator_filters'] + \
375379
['generator_upscale', 'generator_downscale']}
376-
filename = 'ne%ix-%s-%s-%s.pkl.bz2' % (args.zoom, args.type, args.model, __version__)
377-
pickle.dump((config, params), bz2.open(filename, 'wb'))
378-
print(' - Saved model as `{}` after training.'.format(filename))
380+
381+
pickle.dump((config, params), bz2.open(self.get_filename(), 'wb'))
382+
print(' - Saved model as `{}` after training.'.format(self.get_filename()))
379383

380384
def load_model(self):
381-
filename = 'ne%ix-%s-%s-%s.pkl.bz2' % (args.zoom, args.type, args.model, __version__)
382-
if not os.path.exists(filename):
385+
if not os.path.exists(self.get_filename()):
383386
if args.train: return {}, {}
384387
error("Model file with pre-trained convolution layers not found. Download it here...",
385-
"https://github.com/alexjc/neural-enhance/releases/download/v%s/%s"%(__version__, filename))
386-
print(' - Loaded file `{}` with trained model.'.format(filename))
387-
return pickle.load(bz2.open(filename, 'rb'))
388+
"https://github.com/alexjc/neural-enhance/releases/download/v%s/%s"%(__version__, self.get_filename()))
389+
print(' - Loaded file `{}` with trained model.'.format(self.get_filename()))
390+
return pickle.load(bz2.open(self.get_filename(), 'rb'))
388391

389392
def load_generator(self, params):
390393
if len(params) == 0: return

0 commit comments

Comments
 (0)