Skip to content
This repository was archived by the owner on Jan 2, 2021. It is now read-only.

Commit f9d4d50

Browse files
committed
Merge remote-tracking branch 'origin/master'
2 parents 637398c + 02d2fca commit f9d4d50

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ Pre-trained models are provided in the GitHub releases. Training your own is a
6464
6565
# Train the model using an adversarial setup based on [4] below.
6666
python3.4 enhance.py --train "data/*.jpg" --model custom --scales=2 --epochs=250 \
67-
--perceptual-layer=conv5_2 --smoothness-weight=2e4 --adversary-weight=2e5 \
67+
--perceptual-layer=conv5_2 --smoothness-weight=2e4 --adversary-weight=1e3 \
6868
--generator-start=5 --discriminator-start=0 --adversarial-start=5 \
6969
--discriminator-size=64
7070

enhance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def extend(lst): return itertools.chain(lst, itertools.repeat(lst[-1]))
9797
# Load the underlying deep learning libraries based on the device specified. If you specify THEANO_FLAGS manually,
9898
# the code assumes you know what you are doing and they are not overridden!
9999
os.environ.setdefault('THEANO_FLAGS', 'floatX=float32,device={},force_device=True,allow_gc=True,'\
100-
'print_active_device=False,lib.cnmem=1.0'.format(args.device))
100+
'print_active_device=False'.format(args.device))
101101

102102
# Scientific & Imaging Libraries
103103
import numpy as np
@@ -436,7 +436,7 @@ def loss_total_variation(self, x):
436436
return T.mean(((x[:,:,:-1,:-1] - x[:,:,1:,:-1])**2 + (x[:,:,:-1,:-1] - x[:,:,:-1,1:])**2)**1.25)
437437

438438
def loss_adversarial(self, d):
439-
return T.mean(1.0 - T.nnet.softplus(d[args.batch_size:]))
439+
return T.mean(1.0 - T.nnet.softminus(d[args.batch_size:]))
440440

441441
def loss_discriminator(self, d):
442442
return T.mean(T.nnet.softminus(d[args.batch_size:]) - T.nnet.softplus(d[:args.batch_size]))

0 commit comments

Comments (0)