Skip to content

Commit 2991284

Browse files
committed
cuda
1 parent edaf570 commit 2991284

File tree

13 files changed

+819
-246
lines changed

13 files changed

+819
-246
lines changed

benchmarks/benchmark_chamfer.py

Lines changed: 142 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66

77
import chamfer
88

9+
try:
10+
from pytorch3d.loss import chamfer_distance as pytorch3d_chamfer
11+
except ImportError: # pragma: no cover - optional dependency
12+
pytorch3d_chamfer = None
913

1014
def chunked_brute_force(query: torch.Tensor, reference: torch.Tensor, chunk: int = 1024) -> Tuple[torch.Tensor, torch.Tensor]:
1115
"""Memory-friendly brute-force NN by processing reference points in chunks."""
@@ -65,6 +69,11 @@ def mps_sync() -> None:
6569
torch.mps.synchronize()
6670

6771

72+
def cuda_sync() -> None:
    """Wait for pending CUDA work to finish; a no-op when CUDA is unavailable."""
    if not torch.cuda.is_available():
        # Nothing to synchronize on machines without a usable CUDA device.
        return
    torch.cuda.synchronize()
75+
76+
6877
def main() -> None:
6978
parser = argparse.ArgumentParser(description="Benchmark chamfer nearest neighbour implementation.")
7079
parser.add_argument("--n", type=int, default=5_000, help="Number of points per set")
@@ -77,84 +86,183 @@ def main() -> None:
7786
a_cpu = torch.rand(args.n, args.dims)
7887
b_cpu = torch.rand(args.n, args.dims)
7988
mps_available = torch.backends.mps.is_available()
89+
cuda_available = torch.cuda.is_available()
8090

81-
# Warmups
82-
chamfer.closest_points(a_cpu[:256], b_cpu[:256], use_mps=False)
83-
chunked_brute_force(a_cpu[:512], b_cpu[:512], chunk=args.chunk)
91+
# Warmups to trigger compilation/allocation outside timing loops.
92+
chunked_chamfer_loss(a_cpu[:512], b_cpu[:512], chunk=args.chunk)
93+
chamfer.chamfer_distance(a_cpu[:256], b_cpu[:256], use_mps=False)
94+
if pytorch3d_chamfer is not None:
95+
pytorch3d_chamfer(a_cpu[:256].unsqueeze(0), b_cpu[:256].unsqueeze(0))
8496

8597
a_mps = b_mps = None
8698
if mps_available:
8799
a_mps = a_cpu.to("mps")
88100
b_mps = b_cpu.to("mps")
89-
chamfer.closest_points(a_mps[:256], b_mps[:256], use_mps=True)
90-
91-
brute_fwd_time = time_call(lambda: chunked_brute_force(a_cpu, b_cpu, chunk=args.chunk), repeat=args.repeat)
92-
cpu_kd_time = time_call(lambda: chamfer.closest_points(a_cpu, b_cpu, use_mps=False), repeat=args.repeat)
93-
94-
brute_grad_time = time_call(
95-
lambda: chunked_chamfer_loss(
96-
a_cpu.clone().requires_grad_(True),
97-
b_cpu.clone().requires_grad_(True),
98-
chunk=args.chunk,
99-
).backward(),
100-
repeat=args.repeat,
101-
)
101+
chamfer.chamfer_distance(a_mps[:256], b_mps[:256], use_mps=True)
102+
103+
a_cuda = b_cuda = None
104+
if cuda_available:
105+
a_cuda = a_cpu.to("cuda")
106+
b_cuda = b_cpu.to("cuda")
107+
chamfer.chamfer_distance(a_cuda[:256], b_cuda[:256])
108+
if pytorch3d_chamfer is not None:
109+
pytorch3d_chamfer(a_cuda[:256].unsqueeze(0), b_cuda[:256].unsqueeze(0))
110+
cuda_sync()
111+
112+
def brute_forward() -> None:
113+
chunked_chamfer_loss(a_cpu, b_cpu, chunk=args.chunk)
114+
115+
def brute_backward() -> None:
116+
a = a_cpu.clone().requires_grad_(True)
117+
b = b_cpu.clone().requires_grad_(True)
118+
loss = chunked_chamfer_loss(a, b, chunk=args.chunk)
119+
loss.backward()
120+
121+
brute_forward_time = time_call(brute_forward, repeat=args.repeat)
122+
brute_backward_time = time_call(brute_backward, repeat=args.repeat)
102123

103-
def cpu_grad() -> None:
124+
def kd_cpu_forward() -> None:
125+
chamfer.chamfer_distance(a_cpu, b_cpu, use_mps=False)
126+
127+
def kd_cpu_backward() -> None:
104128
a = a_cpu.clone().requires_grad_(True)
105129
b = b_cpu.clone().requires_grad_(True)
106130
loss = chamfer.chamfer_distance(a, b, use_mps=False)
107131
loss.backward()
108132

109-
cpu_grad_time = time_call(cpu_grad, repeat=args.repeat)
133+
cpu_forward_time = time_call(kd_cpu_forward, repeat=args.repeat)
134+
cpu_backward_time = time_call(kd_cpu_backward, repeat=args.repeat)
135+
136+
# Result slots for the CUDA timings; None means "not measured".
# The PyTorch3D slots must use the pyt3d_cuda_* spelling, because that is the
# name assigned by the timing code and read when the table rows are built.
# Initialising a differently named variable leaves pyt3d_cuda_* unbound
# (NameError) whenever CUDA or PyTorch3D is unavailable.
kd_cuda_forward_time = None
kd_cuda_backward_time = None
pyt3d_cuda_forward_time = None
pyt3d_cuda_backward_time = None
140+
if cuda_available and a_cuda is not None and b_cuda is not None:
141+
def kd_cuda_forward() -> None:
142+
chamfer.chamfer_distance(a_cuda, b_cuda)
143+
144+
kd_cuda_forward_time = time_call(kd_cuda_forward, sync=cuda_sync, repeat=args.repeat)
145+
146+
def kd_cuda_backward() -> None:
147+
a = a_cuda.clone().requires_grad_(True)
148+
b = b_cuda.clone().requires_grad_(True)
149+
loss = chamfer.chamfer_distance(a, b)
150+
loss.backward()
151+
152+
kd_cuda_backward_time = time_call(kd_cuda_backward, sync=cuda_sync, repeat=args.repeat)
153+
154+
if pytorch3d_chamfer is not None:
155+
def pyt3d_cuda_forward() -> torch.Tensor:
    """Run one PyTorch3D chamfer forward pass on the CUDA tensors.

    Returns:
        The loss, i.e. the first element of the pair returned by
        ``pytorch3d_chamfer``.
    """
    # unsqueeze(0) adds a leading dimension of size 1; presumably the batch
    # axis PyTorch3D expects -- confirm against its documentation.
    loss, _ = pytorch3d_chamfer(a_cuda.unsqueeze(0), b_cuda.unsqueeze(0))
    return loss
158+
159+
pyt3d_cuda_forward_time = time_call(pyt3d_cuda_forward, sync=cuda_sync, repeat=args.repeat)
160+
161+
def pyt3d_cuda_backward() -> None:
162+
a = a_cuda.unsqueeze(0).clone().requires_grad_(True)
163+
b = b_cuda.unsqueeze(0).clone().requires_grad_(True)
164+
loss, _ = pytorch3d_chamfer(a, b)
165+
loss.backward()
166+
167+
pyt3d_cuda_backward_time = time_call(pyt3d_cuda_backward, sync=cuda_sync, repeat=args.repeat)
110168

111-
kd_mps_time = None
112-
mps_grad_time = None
169+
kd_mps_forward_time = None
170+
kd_mps_backward_time = None
113171
if mps_available and a_mps is not None and b_mps is not None:
114-
kd_mps_time = time_call(
115-
lambda: chamfer.closest_points(a_mps, b_mps, use_mps=True),
116-
sync=mps_sync,
117-
repeat=args.repeat,
118-
)
172+
def kd_mps_forward() -> None:
173+
chamfer.chamfer_distance(a_mps, b_mps, use_mps=True)
174+
175+
kd_mps_forward_time = time_call(kd_mps_forward, sync=mps_sync, repeat=args.repeat)
119176

120-
def mps_grad() -> None:
177+
def kd_mps_backward() -> None:
121178
a = a_mps.clone().requires_grad_(True)
122179
b = b_mps.clone().requires_grad_(True)
123180
loss = chamfer.chamfer_distance(a, b, use_mps=True)
124181
loss.backward()
125182

126-
mps_grad_time = time_call(mps_grad, sync=mps_sync, repeat=args.repeat)
183+
kd_mps_backward_time = time_call(kd_mps_backward, sync=mps_sync, repeat=args.repeat)
184+
185+
pyt3d_cpu_forward_time = None
186+
pyt3d_cpu_backward_time = None
187+
if pytorch3d_chamfer is not None:
188+
def pyt3d_cpu_forward() -> torch.Tensor:
    """Run one PyTorch3D chamfer forward pass on the CPU tensors.

    Returns:
        The loss, i.e. the first element of the pair returned by
        ``pytorch3d_chamfer``.
    """
    # unsqueeze(0) adds a leading dimension of size 1; presumably the batch
    # axis PyTorch3D expects -- confirm against its documentation.
    loss, _ = pytorch3d_chamfer(a_cpu.unsqueeze(0), b_cpu.unsqueeze(0))
    return loss
191+
192+
pyt3d_cpu_forward_time = time_call(pyt3d_cpu_forward, repeat=args.repeat)
193+
194+
def pyt3d_cpu_backward() -> None:
195+
a = a_cpu.unsqueeze(0).clone().requires_grad_(True)
196+
b = b_cpu.unsqueeze(0).clone().requires_grad_(True)
197+
loss, _ = pytorch3d_chamfer(a, b)
198+
loss.backward()
199+
200+
pyt3d_cpu_backward_time = time_call(pyt3d_cpu_backward, repeat=args.repeat)
127201

128202
# Prepare table rows
129203
rows = []
130204

131205
rows.append(
132206
(
133207
"Brute force",
134-
f"{brute_fwd_time:.3f} s",
135-
f"{brute_grad_time:.3f} s",
208+
f"{brute_forward_time:.3f} s",
209+
f"{brute_backward_time:.3f} s",
136210
)
137211
)
138212

139213
rows.append(
140214
(
141215
"KD-tree CPU",
142-
f"{cpu_kd_time:.3f} s ({brute_fwd_time / cpu_kd_time:.2f}x)",
143-
f"{cpu_grad_time:.3f} s ({brute_grad_time / cpu_grad_time:.2f}x)",
216+
f"{cpu_forward_time:.3f} s ({brute_forward_time / cpu_forward_time:.2f}x)",
217+
f"{cpu_backward_time:.3f} s ({brute_backward_time / cpu_backward_time:.2f}x)",
144218
)
145219
)
146220

147-
if kd_mps_time is not None and mps_grad_time is not None:
221+
if kd_cuda_forward_time is not None and kd_cuda_backward_time is not None:
222+
rows.append(
223+
(
224+
"KD-tree CUDA",
225+
f"{kd_cuda_forward_time:.3f} s ({brute_forward_time / kd_cuda_forward_time:.2f}x)",
226+
f"{kd_cuda_backward_time:.3f} s ({brute_backward_time / kd_cuda_backward_time:.2f}x)",
227+
)
228+
)
229+
else:
230+
rows.append(("KD-tree CUDA", "n/a", "n/a"))
231+
232+
if kd_mps_forward_time is not None and kd_mps_backward_time is not None:
148233
rows.append(
149234
(
150235
"KD-tree MPS",
151-
f"{kd_mps_time:.3f} s ({brute_fwd_time / kd_mps_time:.2f}x)",
152-
f"{mps_grad_time:.3f} s ({brute_grad_time / mps_grad_time:.2f}x)",
236+
f"{kd_mps_forward_time:.3f} s ({brute_forward_time / kd_mps_forward_time:.2f}x)",
237+
f"{kd_mps_backward_time:.3f} s ({brute_backward_time / kd_mps_backward_time:.2f}x)",
153238
)
154239
)
155240
else:
156241
rows.append(("KD-tree MPS", "n/a", "n/a"))
157242

243+
# PyTorch3D is an optional dependency. The timings are only non-None when it
# was importable, so checking them is sufficient. When it was not measured,
# still emit the row with "n/a", matching the CUDA, MPS and PyTorch3D CUDA
# rows, so the table has the same layout on every machine.
if pyt3d_cpu_forward_time is not None and pyt3d_cpu_backward_time is not None:
    rows.append(
        (
            "PyTorch3D CPU",
            f"{pyt3d_cpu_forward_time:.3f} s ({brute_forward_time / pyt3d_cpu_forward_time:.2f}x)",
            f"{pyt3d_cpu_backward_time:.3f} s ({brute_backward_time / pyt3d_cpu_backward_time:.2f}x)",
        )
    )
else:
    rows.append(("PyTorch3D CPU", "n/a", "n/a"))
251+
252+
if (
253+
pyt3d_cuda_forward_time is not None
254+
and pyt3d_cuda_backward_time is not None
255+
):
256+
rows.append(
257+
(
258+
"PyTorch3D CUDA",
259+
f"{pyt3d_cuda_forward_time:.3f} s ({brute_forward_time / pyt3d_cuda_forward_time:.2f}x)",
260+
f"{pyt3d_cuda_backward_time:.3f} s ({brute_backward_time / pyt3d_cuda_backward_time:.2f}x)",
261+
)
262+
)
263+
else:
264+
rows.append(("PyTorch3D CUDA", "n/a", "n/a"))
265+
158266
header = ("Method", "Forward", "Backward")
159267
widths = [max(len(col), max(len(row[i]) for row in rows)) for i, col in enumerate(header)]
160268

0 commit comments

Comments
 (0)