The generator uses an encoder-decoder architecture.
There are two options for the generator encoder:
a. ResNet50 minus the last 2 layers
b. ResNet50 + ASPP module
The decoder of the generator has seven upsampling convolutional blocks. Each block consists of an upsampling layer, followed by a convolutional layer, a batch normalization layer, and a ReLU activation.
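One such upsampling convolutional block can be sketched in PyTorch as follows. This is a minimal illustration of the layer order described above, not the code from this repository: the class name `UpsampleBlock`, the bilinear upsampling mode, the scale factor of 2, and the 3x3 kernel are all assumptions made for the example.

```python
import torch
import torch.nn as nn


class UpsampleBlock(nn.Module):
    """Upsampling -> convolution -> batch normalization -> ReLU (hypothetical name)."""

    def __init__(self, in_channels: int, out_channels: int, scale_factor: int = 2):
        super().__init__()
        self.block = nn.Sequential(
            # Assumption: bilinear interpolation; the repository may use another mode.
            nn.Upsample(scale_factor=scale_factor, mode="bilinear", align_corners=False),
            # Assumption: 3x3 kernel with padding 1, so the convolution keeps the spatial size.
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


if __name__ == "__main__":
    block = UpsampleBlock(in_channels=64, out_channels=32)
    features = torch.randn(2, 64, 16, 16)
    print(block(features).shape)  # torch.Size([2, 32, 32, 32])
```

With these assumed settings each block doubles the height and width of its input while changing the channel count.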
The discriminator is a PatchGAN discriminator. The implementation is based on the CycleGAN implementation at https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
Again, two types of discriminator are available:
a. N-layer PatchGAN discriminator, where the patch size is NxN (taken as 3x3 here)
b. Pixel PatchGAN discriminator, which classifies every pixel
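To get a feel for how much of the input image each PatchGAN output score depends on, the receptive field can be computed from the layer strides. The sketch below assumes the layer layout of the N-layer discriminator in the referenced pix2pix/CycleGAN code (4x4 kernels, `n_layers` stride-2 convolutions followed by two stride-1 convolutions); it has not been checked against the discriminator in this repository, and the function name `patchgan_receptive_field` is made up for the example.

```python
def patchgan_receptive_field(n_layers: int, kernel_size: int = 4) -> int:
    """Side length (in input pixels) seen by one output score.

    Assumed layout: n_layers convolutions with stride 2, then two with stride 1,
    all using the same square kernel.
    """
    strides = [2] * n_layers + [1, 1]
    rf = 1
    # Walk from the output back to the input, growing the field layer by layer.
    for stride in reversed(strides):
        rf = (rf - 1) * stride + kernel_size
    return rf


if __name__ == "__main__":
    for n in (1, 2, 3):
        print(n, patchgan_receptive_field(n))
    # 1 16
    # 2 34
    # 3 70
```

Under the same assumption, the pixel discriminator in that code base uses only 1x1 convolutions, so each of its scores depends on a single input pixel.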
Use the --dataroot argument to specify the directory where the data is stored. Structure the data as follows:
- train
  - alpha
  - bg
  - fg
- test
  - fg
  - trimap
I used the MS COCO dataset for the background images.
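Before starting a training run, the directory layout described above can be checked with a few lines of Python. This is a small helper sketch, not part of the repository; the function name `missing_dirs` is hypothetical, and it only verifies that the folders exist, not that they contain valid images.

```python
from pathlib import Path

# Sub-directories expected under each split, taken from the layout described above.
EXPECTED_LAYOUT = {
    "train": ("alpha", "bg", "fg"),
    "test": ("fg", "trimap"),
}


def missing_dirs(dataroot):
    """Return the expected sub-directories (relative paths) that are absent under dataroot."""
    root = Path(dataroot)
    missing = []
    for split, subdirs in EXPECTED_LAYOUT.items():
        for sub in subdirs:
            path = root / split / sub
            if not path.is_dir():
                missing.append(path.relative_to(root).as_posix())
    return missing
```

Calling `missing_dirs("./")` with the same path that is passed to --dataroot returns an empty list when every expected folder is present, and otherwise lists entries such as `train/alpha`.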
To train the model using ResNet50 without the ASPP module:
!python train.py --dataroot ./ --model simple --dataset_mode generated_simple --which_model_netG resnet50 --name resnet50
To test the model using ResNet50 without the ASPP module:
!python test.py --dataroot ./ --dataset_mode single --which_model_netG resnet50 --ntest 8 --model test --name resnet50
To train the model using ResNet50 with the ASPP module:
!python train.py --dataroot ./ --model simple --dataset_mode generated_simple --which_model_netG resnet50ASPP --name resnet50ASPP
To test the model using ResNet50 with the ASPP module:
!python test.py --dataroot ./ --dataset_mode single --which_model_netG resnet50ASPP --ntest 8 --model test --name resnet50ASPP
(The values shown are average ranks on alphamatting.com; a lower rank is better.)
| Error type | Original implementation | ResNet50 + N-layer | ResNet50 + Pixel | ResNet50 + ASPP module |
|---|---|---|---|---|
| Sum of absolute differences | 11.7 | 42.8 | 43.8 | 53 |
| Mean square error | 15 | 45.8 | 45.6 | 54.2 |
| Gradient error | 14 | 52.9 | 52.7 | 55 |
| Connectivity error | 29.6 | 23.3 | 22.6 | 32.8 |
I used the following link for the training dataset: Link to created dataset




