Skip to content

Commit ed5a980

Browse files
committed
add 3D effect demo
1 parent b8165b1 commit ed5a980

File tree

3 files changed

+59
-1
lines changed

3 files changed

+59
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ Follow the instructions on [RoMa](https://github.com/Parskatt/RoMa).
2121

2222
## Demo
2323

24-
A matching demo is provided in the [demos folder](demo).
24+
Two matching demos are provided in the [demos folder](demo). They are slightly adapted from the original RoMa demos for convenience and code clarity.
2525

2626
See [RoMa](https://github.com/Parskatt/RoMa?tab=readme-ov-file#demo--how-to-use) for more details.
2727

demo/demo_3D_effect.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from PIL import Image
2+
import torch
3+
import torch.nn.functional as F
4+
import numpy as np
5+
from romatch import roma_extre
6+
from romatch.utils.utils import tensor_to_pil
7+
8+
9+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10+
if torch.backends.mps.is_available():
11+
device = torch.device("mps")
12+
13+
if __name__ == "__main__":
14+
from argparse import ArgumentParser
15+
parser = ArgumentParser()
16+
parser.add_argument(
17+
"--im_A_path", default="demo/assets/berlin_A.jpg", type=str)
18+
parser.add_argument(
19+
"--im_B_path", default="demo/assets/berlin_B.jpg", type=str)
20+
parser.add_argument(
21+
"--save_path", default="demo/gif/roma_warp_berlin", type=str)
22+
23+
args, _ = parser.parse_known_args()
24+
im1_path = args.im_A_path
25+
im2_path = args.im_B_path
26+
save_path = args.save_path
27+
28+
# Create model
29+
roma_model = roma_extre(
30+
device=device,
31+
coarse_res=560,
32+
upsample_res=(864, 1152)
33+
)
34+
roma_model.symmetric = False
35+
36+
H, W = roma_model.get_output_resolution()
37+
im1 = Image.open(im1_path).resize((W, H))
38+
im2 = Image.open(im2_path).resize((W, H))
39+
40+
# Match
41+
warp, certainty = roma_model.match(im1_path, im2_path, device=device)
42+
# Sampling not needed, but can be done with model.sample(warp, certainty)
43+
44+
x1 = torch.tensor(np.array(im1)).permute(2, 0, 1).to(device) / 255
45+
x2 = torch.tensor(np.array(im2)).permute(2, 0, 1).to(device) / 255
46+
47+
coords_A, coords_B = warp[..., :2], warp[..., 2:]
48+
for i, x in enumerate(np.linspace(0, 2 * np.pi, 200)):
49+
t = (1 + np.cos(x)) / 2
50+
interp_warp = (1 - t) * coords_A + t * coords_B
51+
im2_transfer_rgb = F.grid_sample(
52+
x2[None], interp_warp[None], mode="bilinear", align_corners=False
53+
)[0]
54+
55+
tensor_to_pil(im2_transfer_rgb, unnormalize=False)\
56+
.save(f"{save_path}_{i:03d}.jpg")

demo/gif/.gitkeep

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
*
2+
!.gitignore

0 commit comments

Comments
 (0)