Description
I’ve noticed that many people say the rFID tested with the checkpoint provided by VAR is 0.92. However, when I followed the tutorial provided by VAR, the rFID I got was 2.70, and I’m not sure why.
First, we load the ImageNet val dataset and resize each image with center_crop_arr (taken from LlamaGen; a sketch of the assumed implementation follows). We then obtain the reconstructed images and save them in PNG format via Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png"). Finally, we pack the saved reconstruction images into an .npz file with create_npz_from_sample_folder.
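For reference, here is a minimal sketch of what we assume center_crop_arr does, following the ADM-style crop that LlamaGen ships (repeated BOX halving, a BICUBIC resize so the short side equals the target size, then a center crop):

# Assumed ADM-style center crop (as shipped in LlamaGen's augmentation module) - a sketch, not the exact file.
import numpy as np
from PIL import Image

def center_crop_arr(pil_image, image_size):
    # Halve with BOX filtering while the short side is still >= 2x the target size.
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)

    # One BICUBIC resize so the short side equals image_size.
    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)

    # Center-crop to image_size x image_size and return a PIL image for ToTensor.
    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])

The full reconstruction script is below.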
import os
import os.path as osp
import torch, torchvision
import random
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from tqdm import tqdm
from PIL import Image
import argparse
import itertools

from models import VQVAE, build_vae_var
from metric import PSNR, LPIPS, SSIM
from augmentation import center_crop_arr
def create_npz_from_sample_folder(sample_dir, num=50000):
    """
    Builds a single .npz file from a folder of .png samples.
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
        sample_np = np.asarray(sample_pil).astype(np.uint8)
        samples.append(sample_np)
    samples = np.stack(samples)
    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
def load_dataset(data_path, batch_size=16):
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])
    val_dataset = ImageFolder(root=os.path.join(data_path, 'val'), transform=transform)
    len_val_set = len(val_dataset)
    dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=6, drop_last=False)
    return dataloader, len_val_set
def main():
    sample_folder_dir = "/projects/yuanai/processed_data/rFID/baselines/VAR"
    save_npz_name = "var_reconstruction_imagenet256.npz"

    # load the dataset
    data_path = "/projects/yuanai/data/ImageNet/"
    val_dataloader, len_val_set = load_dataset(data_path, batch_size=16)
    num_fid_samples = 50000

    # load the VQVAE checkpoint
    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,  # hard-coded VQVAE hyperparameters
        device='cuda', patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
        num_classes=1000, depth=16, shared_aln=False,
    )
    vae_ckpt = "/projects/yuanai/processed_data/checkpoint/VAR/vae_ch160v4096z32.pth"
    vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
    vae = vae.cuda()
    vae.eval()

    psnr_metric = PSNR()
    ssim_metric = SSIM()
    lpips_metric = LPIPS()
    ssim, psnr, lpips = 0.0, 0.0, 0.0
    total = 0

    for idx, (x, _) in enumerate(val_dataloader):
        x = x.cuda()
        with torch.no_grad():
            x_rec = vae.img_to_reconstructed_img(x, v_patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), last_one=True)
            batch_lpips = lpips_metric(x, x_rec).sum()

        # map reconstructions from [-1, 1] back to uint8 pixels
        samples = torch.clamp(127.5 * x_rec + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
        # save samples to disk as individual .png files
        for i, sample in enumerate(samples):
            index = i + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += samples.shape[0]  # count by actual batch size (robust if the last batch is smaller)

        x_norm = (x + 1.0) / 2.0
        x_rec_norm = (x_rec + 1.0) / 2.0
        batch_psnr = psnr_metric(x_norm, x_rec_norm).sum()
        batch_ssim = ssim_metric(x_norm, x_rec_norm).sum()
        ssim += batch_ssim.item()
        psnr += batch_psnr.item()
        lpips += batch_lpips.item()

    eval_psnr = psnr / len_val_set
    eval_ssim = ssim / len_val_set
    eval_lpips = lpips / len_val_set
    print("PSNR: " + str(eval_psnr) + " SSIM: " + str(eval_ssim) + " LPIPS: " + str(eval_lpips))

    create_npz_from_sample_folder(sample_folder_dir, num_fid_samples)


if __name__ == "__main__":
    main()
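Before running the evaluator, a quick sanity check of the packed file can confirm its shape and dtype (illustrative snippet; the path is the one produced by create_npz_from_sample_folder above):

# Minimal sanity check of the packed reconstructions (illustrative).
import numpy as np

arr = np.load("/projects/yuanai/processed_data/rFID/baselines/VAR.npz")["arr_0"]
print(arr.shape, arr.dtype)  # expected: (50000, 256, 256, 3) uint8
print(arr.min(), arr.max())  # expected: values spanning roughly 0..255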
Based on this .npz file, we compute rFID with the OpenAI evaluation toolkit via "python evaluator.py VIRTUAL_imagenet256_labeled.npz our_sampled_imagenet256.npz". The results are:
Inception Score: 56.86065673828125
FID: 2.70789722116308
sFID: 4.6903826389984715
Precision: 0.74194
Recall: 0.6662
Can anyone tell me why my rFID result is so far from the reported 0.92, and whether there is a problem with this evaluation procedure?