|
| 1 | +from PIL import Image |
| 2 | +import torch |
| 3 | +import torch.nn.functional as F |
| 4 | +import numpy as np |
| 5 | +from romatch import roma_extre |
| 6 | +from romatch.utils.utils import tensor_to_pil |
| 7 | + |
| 8 | + |
| 9 | +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| 10 | +if torch.backends.mps.is_available(): |
| 11 | + device = torch.device("mps") |
| 12 | + |
| 13 | +if __name__ == "__main__": |
| 14 | + from argparse import ArgumentParser |
| 15 | + parser = ArgumentParser() |
| 16 | + parser.add_argument( |
| 17 | + "--im_A_path", default="demo/assets/berlin_A.jpg", type=str) |
| 18 | + parser.add_argument( |
| 19 | + "--im_B_path", default="demo/assets/berlin_B.jpg", type=str) |
| 20 | + parser.add_argument( |
| 21 | + "--save_path", default="demo/gif/roma_warp_berlin", type=str) |
| 22 | + |
| 23 | + args, _ = parser.parse_known_args() |
| 24 | + im1_path = args.im_A_path |
| 25 | + im2_path = args.im_B_path |
| 26 | + save_path = args.save_path |
| 27 | + |
| 28 | + # Create model |
| 29 | + roma_model = roma_extre( |
| 30 | + device=device, |
| 31 | + coarse_res=560, |
| 32 | + upsample_res=(864, 1152) |
| 33 | + ) |
| 34 | + roma_model.symmetric = False |
| 35 | + |
| 36 | + H, W = roma_model.get_output_resolution() |
| 37 | + im1 = Image.open(im1_path).resize((W, H)) |
| 38 | + im2 = Image.open(im2_path).resize((W, H)) |
| 39 | + |
| 40 | + # Match |
| 41 | + warp, certainty = roma_model.match(im1_path, im2_path, device=device) |
| 42 | + # Sampling not needed, but can be done with model.sample(warp, certainty) |
| 43 | + |
| 44 | + x1 = torch.tensor(np.array(im1)).permute(2, 0, 1).to(device) / 255 |
| 45 | + x2 = torch.tensor(np.array(im2)).permute(2, 0, 1).to(device) / 255 |
| 46 | + |
| 47 | + coords_A, coords_B = warp[..., :2], warp[..., 2:] |
| 48 | + for i, x in enumerate(np.linspace(0, 2 * np.pi, 200)): |
| 49 | + t = (1 + np.cos(x)) / 2 |
| 50 | + interp_warp = (1 - t) * coords_A + t * coords_B |
| 51 | + im2_transfer_rgb = F.grid_sample( |
| 52 | + x2[None], interp_warp[None], mode="bilinear", align_corners=False |
| 53 | + )[0] |
| 54 | + |
| 55 | + tensor_to_pil(im2_transfer_rgb, unnormalize=False)\ |
| 56 | + .save(f"{save_path}_{i:03d}.jpg") |
0 commit comments