diff --git a/README.md b/README.md index f3a7ff1d..19a111cd 100755 --- a/README.md +++ b/README.md @@ -130,6 +130,8 @@ th -ldisplay.start 8000 0.0.0.0 ``` Then open `http://(hostname):(port)/` in your browser to load the remote desktop. +L1 error is plotted to the display by default. Set the environment variable `display_plot` to a comma-separated list of values `errL1`, `errG` and `errD` to visualize the L1, generator, and discriminator error respectively. For example, to plot only the generator and discriminator errors to the display instead of the default L1 error, set `display_plot="errG,errD"`. + ## Citation If you use this code for your research, please cite our paper Image-to-Image Translation Using Conditional Adversarial Networks: diff --git a/data/combine_A_and_B.py b/data/combine_A_and_B.py index 035830be..c9d5e4ca 100755 --- a/data/combine_A_and_B.py +++ b/data/combine_A_and_B.py @@ -43,8 +43,8 @@ if args.use_AB: name_AB = name_AB.replace('_A.', '.') # remove _A path_AB = os.path.join(img_fold_AB, name_AB) - im_A = cv2.imread(path_A, cv2.CV_LOAD_IMAGE_COLOR) - im_B = cv2.imread(path_B, cv2.CV_LOAD_IMAGE_COLOR) + im_A = cv2.imread(path_A, cv2.IMREAD_COLOR) + im_B = cv2.imread(path_B, cv2.IMREAD_COLOR) im_AB = np.concatenate([im_A, im_B], 1) cv2.imwrite(path_AB, im_AB) diff --git a/data/donkey_folder.lua b/data/donkey_folder.lua index 9f53dc01..bc6a81ea 100755 --- a/data/donkey_folder.lua +++ b/data/donkey_folder.lua @@ -33,12 +33,12 @@ local trainCache = paths.concat(cache, cache_prefix .. 
'_trainCache.t7') -------------------------------------------------------------------------------------------- local input_nc = opt.input_nc -- input channels local output_nc = opt.output_nc -local loadSize = {input_nc, opt.loadSize} -local sampleSize = {input_nc, opt.fineSize} +local loadSize = {input_nc, opt.imgWidth, opt.imgHeight} +local sampleSize = {input_nc, opt.imgWidth, opt.imgHeight} local preprocessAandB = function(imA, imB) - imA = image.scale(imA, loadSize[2], loadSize[2]) - imB = image.scale(imB, loadSize[2], loadSize[2]) + imA = image.scale(imA, loadSize[2], loadSize[3]) + imB = image.scale(imB, loadSize[2], loadSize[3]) local perm = torch.LongTensor{3, 2, 1} imA = imA:index(1, perm)--:mul(256.0): brg, rgb imA = imA:mul(2):add(-1) @@ -52,7 +52,7 @@ local preprocessAandB = function(imA, imB) local oW = sampleSize[2] - local oH = sampleSize[2] + local oH = sampleSize[3] local iH = imA:size(2) local iW = imA:size(3) @@ -80,10 +80,10 @@ end local function loadImageChannel(path) local input = image.load(path, 3, 'float') - input = image.scale(input, loadSize[2], loadSize[2]) + input = image.scale(input, loadSize[2], loadSize[3]) local oW = sampleSize[2] - local oH = sampleSize[2] + local oH = sampleSize[3] local iH = input:size(2) local iW = input:size(3) @@ -161,7 +161,7 @@ print('trainCache', trainCache) -- trainLoader = torch.load(trainCache) -- trainLoader.sampleHookTrain = trainHook -- trainLoader.loadSize = {input_nc, opt.loadSize, opt.loadSize} --- trainLoader.sampleSize = {input_nc+output_nc, sampleSize[2], sampleSize[2]} +-- trainLoader.sampleSize = {input_nc+output_nc, sampleSize[2], sampleSize[3]} -- trainLoader.serial_batches = opt.serial_batches -- trainLoader.split = 100 --else @@ -170,8 +170,8 @@ print('Creating train metadata') print('serial batch:, ', opt.serial_batches) trainLoader = dataLoader{ paths = {opt.data}, - loadSize = {input_nc, loadSize[2], loadSize[2]}, - sampleSize = {input_nc+output_nc, sampleSize[2], sampleSize[2]}, + 
loadSize = {input_nc, loadSize[2], loadSize[3]}, + sampleSize = {input_nc+output_nc, sampleSize[2], sampleSize[3]}, split = 100, serial_batches = opt.serial_batches, verbose = true @@ -189,4 +189,4 @@ do local nClasses = #trainLoader.classes assert(class:max() <= nClasses, "class logic has error") assert(class:min() >= 1, "class logic has error") -end \ No newline at end of file +end diff --git a/test.lua b/test.lua index 60676b73..45176dcc 100755 --- a/test.lua +++ b/test.lua @@ -13,7 +13,8 @@ opt = { DATA_ROOT = '', -- path to images (should have subfolders 'train', 'val', etc) batchSize = 1, -- # images in batch loadSize = 256, -- scale images to this size - fineSize = 256, -- then crop to this size + imgWidth = 256, -- then crop to this size. Both should be multiples of 32... + imgHeight = 512, flip=0, -- horizontal mirroring data augmentation display = 1, -- display samples while training. 0 = false display_id = 200, -- display window id. @@ -69,8 +70,8 @@ else end ---------------------------------------------------------------------------- -local input = torch.FloatTensor(opt.batchSize,3,opt.fineSize,opt.fineSize) -local target = torch.FloatTensor(opt.batchSize,3,opt.fineSize,opt.fineSize) +local input = torch.FloatTensor(opt.batchSize,3,opt.imgHeight,opt.imgWidth) +local target = torch.FloatTensor(opt.batchSize,3,opt.imgHeight,opt.imgWidth) print('checkpoints_dir', opt.checkpoints_dir) local netG = util.load(paths.concat(opt.checkpoints_dir, opt.netG_name .. 
'.t7'), opt) @@ -129,18 +130,18 @@ for n=1,math.floor(opt.how_many/opt.batchSize) do print(output:size()) print(target:size()) for i=1, opt.batchSize do - image.save(paths.concat(image_dir,'input',filepaths_curr[i]), image.scale(input[i],input[i]:size(2),input[i]:size(3)/opt.aspect_ratio)) - image.save(paths.concat(image_dir,'output',filepaths_curr[i]), image.scale(output[i],output[i]:size(2),output[i]:size(3)/opt.aspect_ratio)) - image.save(paths.concat(image_dir,'target',filepaths_curr[i]), image.scale(target[i],target[i]:size(2),target[i]:size(3)/opt.aspect_ratio)) + image.save(paths.concat(image_dir,'input',filepaths_curr[i]), input[i])--image.scale(input[i],input[i]:size(3),input[i]:size(2))) --/opt.aspect_ratio)) + image.save(paths.concat(image_dir,'output',filepaths_curr[i]), output[i])--image.scale(output[i],output[i]:size(3),output[i]:size(2)))--/opt.aspect_ratio)) + image.save(paths.concat(image_dir,'target',filepaths_curr[i]), target[i])--image.scale(target[i],target[i]:size(3),target[i]:size(2)))--/opt.aspect_ratio)) end print('Saved images to: ', image_dir) if opt.display then if opt.preprocess == 'regular' then disp = require 'display' - disp.image(util.scaleBatch(input,100,100),{win=opt.display_id, title='input'}) - disp.image(util.scaleBatch(output,100,100),{win=opt.display_id+1, title='output'}) - disp.image(util.scaleBatch(target,100,100),{win=opt.display_id+2, title='target'}) + disp.image(util.scaleBatch(input,512,256),{win=opt.display_id, title='input'}) + disp.image(util.scaleBatch(output,512,256),{win=opt.display_id+1, title='output'}) + disp.image(util.scaleBatch(target,512,256),{win=opt.display_id+2, title='target'}) print('Displayed images') end @@ -164,4 +165,4 @@ for i=1, #filepaths do io.write('') end -io.write('') \ No newline at end of file +io.write('') diff --git a/train.lua b/train.lua old mode 100755 new mode 100644 index f1e1bc06..5d8133d0 --- a/train.lua +++ b/train.lua @@ -15,7 +15,8 @@ opt = { DATA_ROOT = '', -- path to images 
(should have subfolders 'train', 'val', etc) batchSize = 1, -- # images in batch loadSize = 286, -- scale images to this size - fineSize = 256, -- then crop to this size + imgWidth = 256, -- then crop to this size. Both should be multiples of 32... + imgHeight = 512, ngf = 64, -- # of gen filters in first conv layer ndf = 64, -- # of discrim filters in first conv layer input_nc = 3, -- # of input image channels @@ -27,6 +28,7 @@ opt = { flip = 1, -- if flip the images for data argumentation display = 1, -- display samples while training. 0 = false display_id = 10, -- display window id. + display_plot = 'errL1', -- which loss values to plot over time. Accepted values include a comma separated list of: errL1, errG, and errD gpu = 1, -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X name = '', -- name of the experiment, should generally be passed on the command line which_direction = 'AtoB', -- AtoB or BtoA @@ -36,7 +38,7 @@ opt = { save_epoch_freq = 50, -- save a model every save_epoch_freq epochs (does not overwrite previously saved models) save_latest_freq = 5000, -- save the latest model every latest_freq sgd iterations (overwrites the previous latest model) print_freq = 50, -- print the debug information every print_freq iterations - display_freq = 100, -- display the current results every display_freq iterations + display_freq = 20, -- display the current results every display_freq iterations save_display_freq = 5000, -- save the current display of results every save_display_freq_iterations continue_train=0, -- if continue training, load the latest model: 1: true, 0: false serial_batches = 0, -- if 1, takes images in order to make batches, otherwise takes them randomly @@ -165,11 +167,11 @@ optimStateD = { beta1 = opt.beta1, } ---------------------------------------------------------------------------- -local real_A = torch.Tensor(opt.batchSize, input_nc, opt.fineSize, opt.fineSize) -local real_B = torch.Tensor(opt.batchSize, output_nc, opt.fineSize, 
opt.fineSize) -local fake_B = torch.Tensor(opt.batchSize, output_nc, opt.fineSize, opt.fineSize) -local real_AB = torch.Tensor(opt.batchSize, output_nc + input_nc*opt.condition_GAN, opt.fineSize, opt.fineSize) -local fake_AB = torch.Tensor(opt.batchSize, output_nc + input_nc*opt.condition_GAN, opt.fineSize, opt.fineSize) +local real_A = torch.Tensor(opt.batchSize, input_nc, opt.imgWidth, opt.imgHeight) +local real_B = torch.Tensor(opt.batchSize, output_nc, opt.imgWidth, opt.imgHeight) +local fake_B = torch.Tensor(opt.batchSize, output_nc, opt.imgWidth, opt.imgHeight) +local real_AB = torch.Tensor(opt.batchSize, output_nc + input_nc*opt.condition_GAN, opt.imgWidth, opt.imgHeight) +local fake_AB = torch.Tensor(opt.batchSize, output_nc + input_nc*opt.condition_GAN, opt.imgWidth, opt.imgHeight) local errD, errG, errL1 = 0, 0, 0 local epoch_tm = torch.Timer() local tm = torch.Timer() @@ -314,6 +316,25 @@ file = torch.DiskFile(paths.concat(opt.checkpoints_dir, opt.name, 'opt.txt'), 'w file:writeObject(opt) file:close() +-- parse display_plot string into table +opt.display_plot = string.split(string.gsub(opt.display_plot, "%s+", ""), ",") +for k, v in ipairs(opt.display_plot) do + if not util.containsValue({"errG", "errD", "errL1"}, v) then + error(string.format('bad display_plot value "%s"', v)) + end +end + +-- display plot config +local plot_config = { + title = "Loss over time", + labels = {"epoch", unpack(opt.display_plot)}, + ylabel = "loss", +} + +-- display plot vars +local plot_data = {} +local plot_win + local counter = 0 for epoch = 1, opt.niter do epoch_tm:reset() @@ -328,22 +349,22 @@ for epoch = 1, opt.niter do -- (2) Update G network: maximize log(D(x,G(x))) + L1(y,G(x)) optim.adam(fGx, parametersG, optimStateG) - + -- display counter = counter + 1 if counter % opt.display_freq == 0 and opt.display then createRealFake() if opt.preprocess == 'colorization' then - local real_A_s = util.scaleBatch(real_A:float(),100,100) - local fake_B_s = 
util.scaleBatch(fake_B:float(),100,100) - local real_B_s = util.scaleBatch(real_B:float(),100,100) + local real_A_s = util.scaleBatch(real_A:float(),512,256) + local fake_B_s = util.scaleBatch(fake_B:float(),512,256) + local real_B_s = util.scaleBatch(real_B:float(),512,256) disp.image(util.deprocessL_batch(real_A_s), {win=opt.display_id, title=opt.name .. ' input'}) disp.image(util.deprocessLAB_batch(real_A_s, fake_B_s), {win=opt.display_id+1, title=opt.name .. ' output'}) disp.image(util.deprocessLAB_batch(real_A_s, real_B_s), {win=opt.display_id+2, title=opt.name .. ' target'}) else - disp.image(util.deprocess_batch(util.scaleBatch(real_A:float(),100,100)), {win=opt.display_id, title=opt.name .. ' input'}) - disp.image(util.deprocess_batch(util.scaleBatch(fake_B:float(),100,100)), {win=opt.display_id+1, title=opt.name .. ' output'}) - disp.image(util.deprocess_batch(util.scaleBatch(real_B:float(),100,100)), {win=opt.display_id+2, title=opt.name .. ' target'}) + disp.image(util.deprocess_batch(util.scaleBatch(real_A:float(),512,256)), {win=opt.display_id, title=opt.name .. ' input'}) + disp.image(util.deprocess_batch(util.scaleBatch(fake_B:float(),512,256)), {win=opt.display_id+1, title=opt.name .. ' output'}) + disp.image(util.deprocess_batch(util.scaleBatch(real_B:float(),512,256)), {win=opt.display_id+2, title=opt.name .. ' target'}) end end @@ -377,14 +398,30 @@ for epoch = 1, opt.niter do opt.serial_batches=serial_batches end - -- logging + -- logging and display plot if counter % opt.print_freq == 0 then + local loss = {errG=errG and errG or -1, errD=errD and errD or -1, errL1=errL1 and errL1 or -1} + local curItInBatch = ((i-1) / opt.batchSize) + local totalItInBatch = math.floor(math.min(data:size(), opt.ntrain) / opt.batchSize) print(('Epoch: [%d][%8d / %8d]\t Time: %.3f DataTime: %.3f ' .. 
' Err_G: %.4f Err_D: %.4f ErrL1: %.4f'):format( - epoch, ((i-1) / opt.batchSize), - math.floor(math.min(data:size(), opt.ntrain) / opt.batchSize), + epoch, curItInBatch, totalItInBatch, tm:time().real / opt.batchSize, data_tm:time().real / opt.batchSize, - errG and errG or -1, errD and errD or -1, errL1 and errL1 or -1)) + errG, errD, errL1)) + + local plot_vals = { epoch + curItInBatch / totalItInBatch } + for k, v in ipairs(opt.display_plot) do + if loss[v] ~= nil then + plot_vals[#plot_vals + 1] = loss[v] + end + end + + -- update display plot + if opt.display then + table.insert(plot_data, plot_vals) + plot_config.win = plot_win + plot_win = disp.plot(plot_data, plot_config) + end end -- save latest model @@ -409,4 +446,4 @@ for epoch = 1, opt.niter do epoch, opt.niter, epoch_tm:time().real)) parametersD, gradParametersD = netD:getParameters() -- reflatten the params and get them parametersG, gradParametersG = netG:getParameters() -end \ No newline at end of file +end diff --git a/util/util.lua b/util/util.lua index 3042ef5a..78acd4b1 100755 --- a/util/util.lua +++ b/util/util.lua @@ -219,4 +219,11 @@ function util.cudnn(net) return cudnn_convert_custom(net, cudnn) end +function util.containsValue(table, value) + for k, v in pairs(table) do + if v == value then return true end + end + return false +end + return util