Based on DCGAN, I wrote my GAN code as follows:

import basic_libraries
from torchmetrics.image.fid import FrechetInceptionDistance
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize(args.image_size),
    transforms.CenterCrop(args.image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = torchvision.datasets.MNIST(root=args.dataroot, train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root=args.dataroot, train=False, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.n_workers
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.n_workers
)
...
netD = NetD() # discriminator
netG = NetG() # generator
fid = FrechetInceptionDistance(feature=64)
netG.fid=fid
...
fid_scores = []
for epoch in range(args.n_epochs):
    # Step 1: train gan model using trainloader
    ...
    # Step 2: evaluate gan model using testloader
    for image, _ in tqdm(testloader):
        noise = torch.randn(args.batch_size, args.nz, 1, 1, device=device)
        with torch.no_grad():
            fake = netG(noise)
        # Undo Normalize(0.5, 0.5): map [-1, 1] back to [0, 1]
        fake.data = fake.data.mul(0.5).add(0.5)
        image.data = image.data.mul(0.5).add(0.5)
        netG.fid.update(convert_image_dtype(fake, torch.uint8), real=False)
        netG.fid.update(convert_image_dtype(image.to(device), torch.uint8), real=True)
    current_fid = netG.fid.compute().item()
    netG.fid.reset()
    fid_scores.append(current_fid)
    print(f'FID = {current_fid:.4f}')

I want to ask whether this code is correct for my purpose: using torchmetrics to evaluate my GAN model on MNIST.
Replies: 1 comment
Your code for evaluating FID with `torchmetrics.image.fid.FrechetInceptionDistance` broadly follows the right steps: applying appropriate image transformations, updating the metric with both real and generated images during evaluation, then computing and resetting the metric scores each epoch. Just ensure that your image tensors passed to `update()` are correctly scaled and converted to uint8, as you are doing. Also, typically FID expects RGB images, so your grayscale-to-3-channel transform is appropriate for MNIST.

One detail: the feature parameter `feature=64` in `FrechetInceptionDistance(feature=64)` is unusual, as the default uses a pre-trained Inception model and the feature extractor is int…
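
For reference, the evaluation steps discussed in this thread could be wrapped up roughly as follows. This is only a sketch under some assumptions: `to_uint8` and `evaluate_fid` are hypothetical helper names (not part of torchmetrics), the images are assumed to be normalized to [-1, 1] as in the question's `Normalize((0.5, ...), (0.5, ...))` transform, and `metric` is assumed to be any object exposing `reset()`, `update(images, real=...)` and `compute()`, which is the interface the question already uses. The rounding in `to_uint8` may differ slightly from torchvision's `convert_image_dtype`.

```python
import torch


def to_uint8(images: torch.Tensor) -> torch.Tensor:
    """Map images normalized to [-1, 1] onto uint8 values in [0, 255].

    Hypothetical helper: assumes the Normalize(0.5, 0.5) transform
    from the question was applied to the inputs.
    """
    images = images.mul(0.5).add(0.5).clamp(0.0, 1.0)  # back to [0, 1]
    return images.mul(255).round().to(torch.uint8)


@torch.no_grad()
def evaluate_fid(metric, generator, real_loader, nz, device):
    """Feed equal numbers of real and generated images to `metric`.

    `metric` is assumed to behave like the FID metric in the question:
    reset(), update(images, real=bool), compute().
    """
    was_training = generator.training
    generator.eval()
    metric.reset()
    for real, _ in real_loader:
        real = real.to(device)
        # Size the noise batch from the real batch, so a smaller final
        # batch does not give the two sides different sample counts.
        noise = torch.randn(real.size(0), nz, 1, 1, device=device)
        fake = generator(noise)
        metric.update(to_uint8(real), real=True)
        metric.update(to_uint8(fake), real=False)
    generator.train(was_training)  # restore the previous mode
    return float(metric.compute())
```

With the question's setup, a call might look like `evaluate_fid(fid, netG, testloader, args.nz, device)` once per epoch, with the metric created as `FrechetInceptionDistance(feature=64).to(device)` so that it lives on the same device as the images. Passing the metric in as an argument, rather than attaching it to `netG`, keeps it out of the generator's submodules and state dict; that is a design preference, not a requirement.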