From 395d7bb9f6a992933c22630f48105437b40a5ebb Mon Sep 17 00:00:00 2001 From: hazirbas Date: Thu, 31 Aug 2017 17:44:25 +0200 Subject: [PATCH] add set_test_mode: Evoke evaluation mode for dropout and batch normalizatio layers during testing. --- models/base_model.py | 3 +++ models/cycle_gan_model.py | 4 ++++ models/pix2pix_model.py | 3 +++ models/test_model.py | 3 +++ test.py | 1 + 5 files changed, 14 insertions(+) diff --git a/models/base_model.py b/models/base_model.py index 36ceb43a8d6..e3b7a353579 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -23,6 +23,9 @@ def forward(self): def test(self): pass + def set_test_mode(self): + pass + def get_image_paths(self): pass diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index b3c52c7f630..662a1c8da76 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -91,6 +91,10 @@ def test(self): self.fake_A = self.netG_B.forward(self.real_B) self.rec_B = self.netG_A.forward(self.fake_A) + def set_test_mode(self): + self.netG_A.eval() + self.netG_B.eval() + # get image paths def get_image_paths(self): return self.image_paths diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index a524f2ceda4..44510d5ccd0 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -73,6 +73,9 @@ def test(self): self.fake_B = self.netG.forward(self.real_A) self.real_B = Variable(self.input_B, volatile=True) + def set_test_mode(self): + self.netG.eval() + # get image paths def get_image_paths(self): return self.image_paths diff --git a/models/test_model.py b/models/test_model.py index 03aef655aac..c8c3a977253 100644 --- a/models/test_model.py +++ b/models/test_model.py @@ -35,6 +35,9 @@ def test(self): self.real_A = Variable(self.input_A) self.fake_B = self.netG.forward(self.real_A) + def set_test_mode(self): + self.netG.eval() + # get image paths def get_image_paths(self): return self.image_paths diff --git a/test.py b/test.py index f019d1047c4..c439dad0bc4 100644 --- a/test.py +++ b/test.py @@ -21,6 +21,7 @@ web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) # test +model.set_test_mode() for i, data in enumerate(dataset): if i >= opt.how_many: break