tensorflow2实现CycleGAN,移植修改自源码
更多次数的训练效果应该还会更好一些。
程序主体架构来自Monkeone,原作者可能是完全复现CycleGAN论文的算法与网络架构,我为了在小显存显卡上方便训练以及更好的网络效果,主要做了如下修改:
- 修改生成器,使用上采样层代替转置卷积
- 减小默认
resnet blocks和判别器隐层数量,方便小显存显卡训练 - 网络输入大小由
helper里cfg提供 - 添加
tensorboard记录训练时的图片 - 修复一处可能存在的
判别器优化器bug - 添加更多命令行参数,可直接指定学习率、恢复训练等
- 重写
helpers.py - 修改训练、测试逻辑
- 完成更多模型的训练,提供更多的演示效果与内置模型
- ...
准备两种不同风格或领域的图片,分为A和B,按如下方式放在data文件夹下:在data文件夹下新建文件夹,用存放在trainA、trainB、testA、testB文件夹,其中train内图片用于训练模型,test内图片用于最后的转换测试,
- 执行
python train.py -h查看帮助信息,一般执行如下命令即可开始训练:
python train.py --data_dir 数据集路径
- 使用
tensorboard查看训练状态:
tensorboard --logdir logs/ --bind_all
浏览器打开http://127.0.0.1:6006
- 如需在训练时查看生成图片信息,在训练时指定
--tensorboard_images_freq参数,参数为更新频率
- 执行
python test.py -h查看帮助信息,一般执行如下命令即可开始测试:
python test.py --data_dir 数据集路径
- 测试结果默认保存在数据集目录下
output文件夹里,也可使用out_dir参数指定保存路径
- Monkeone for main framework.











