66
77import chamfer
88
9+ try :
10+ from pytorch3d .loss import chamfer_distance as pytorch3d_chamfer
11+ except ImportError : # pragma: no cover - optional dependency
12+ pytorch3d_chamfer = None
913
1014def chunked_brute_force (query : torch .Tensor , reference : torch .Tensor , chunk : int = 1024 ) -> Tuple [torch .Tensor , torch .Tensor ]:
1115 """Memory-friendly brute-force NN by processing reference points in chunks."""
@@ -65,6 +69,11 @@ def mps_sync() -> None:
6569 torch .mps .synchronize ()
6670
6771
def cuda_sync() -> None:
    """Block the host until all queued CUDA kernels have finished.

    Safe to call unconditionally: when CUDA is unavailable this is a
    no-op, so it can be passed as the ``sync`` hook for timing loops.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()
75+
76+
6877def main () -> None :
6978 parser = argparse .ArgumentParser (description = "Benchmark chamfer nearest neighbour implementation." )
7079 parser .add_argument ("--n" , type = int , default = 5_000 , help = "Number of points per set" )
@@ -77,84 +86,183 @@ def main() -> None:
7786 a_cpu = torch .rand (args .n , args .dims )
7887 b_cpu = torch .rand (args .n , args .dims )
7988 mps_available = torch .backends .mps .is_available ()
89+ cuda_available = torch .cuda .is_available ()
8090
81- # Warmups
82- chamfer .closest_points (a_cpu [:256 ], b_cpu [:256 ], use_mps = False )
83- chunked_brute_force (a_cpu [:512 ], b_cpu [:512 ], chunk = args .chunk )
91+ # Warmups to trigger compilation/allocation outside timing loops.
92+ chunked_chamfer_loss (a_cpu [:512 ], b_cpu [:512 ], chunk = args .chunk )
93+ chamfer .chamfer_distance (a_cpu [:256 ], b_cpu [:256 ], use_mps = False )
94+ if pytorch3d_chamfer is not None :
95+ pytorch3d_chamfer (a_cpu [:256 ].unsqueeze (0 ), b_cpu [:256 ].unsqueeze (0 ))
8496
8597 a_mps = b_mps = None
8698 if mps_available :
8799 a_mps = a_cpu .to ("mps" )
88100 b_mps = b_cpu .to ("mps" )
89- chamfer .closest_points (a_mps [:256 ], b_mps [:256 ], use_mps = True )
90-
91- brute_fwd_time = time_call (lambda : chunked_brute_force (a_cpu , b_cpu , chunk = args .chunk ), repeat = args .repeat )
92- cpu_kd_time = time_call (lambda : chamfer .closest_points (a_cpu , b_cpu , use_mps = False ), repeat = args .repeat )
93-
94- brute_grad_time = time_call (
95- lambda : chunked_chamfer_loss (
96- a_cpu .clone ().requires_grad_ (True ),
97- b_cpu .clone ().requires_grad_ (True ),
98- chunk = args .chunk ,
99- ).backward (),
100- repeat = args .repeat ,
101- )
101+ chamfer .chamfer_distance (a_mps [:256 ], b_mps [:256 ], use_mps = True )
102+
103+ a_cuda = b_cuda = None
104+ if cuda_available :
105+ a_cuda = a_cpu .to ("cuda" )
106+ b_cuda = b_cpu .to ("cuda" )
107+ chamfer .chamfer_distance (a_cuda [:256 ], b_cuda [:256 ])
108+ if pytorch3d_chamfer is not None :
109+ pytorch3d_chamfer (a_cuda [:256 ].unsqueeze (0 ), b_cuda [:256 ].unsqueeze (0 ))
110+ cuda_sync ()
111+
112+ def brute_forward () -> None :
113+ chunked_chamfer_loss (a_cpu , b_cpu , chunk = args .chunk )
114+
115+ def brute_backward () -> None :
116+ a = a_cpu .clone ().requires_grad_ (True )
117+ b = b_cpu .clone ().requires_grad_ (True )
118+ loss = chunked_chamfer_loss (a , b , chunk = args .chunk )
119+ loss .backward ()
120+
121+ brute_forward_time = time_call (brute_forward , repeat = args .repeat )
122+ brute_backward_time = time_call (brute_backward , repeat = args .repeat )
102123
103- def cpu_grad () -> None :
124+ def kd_cpu_forward () -> None :
125+ chamfer .chamfer_distance (a_cpu , b_cpu , use_mps = False )
126+
127+ def kd_cpu_backward () -> None :
104128 a = a_cpu .clone ().requires_grad_ (True )
105129 b = b_cpu .clone ().requires_grad_ (True )
106130 loss = chamfer .chamfer_distance (a , b , use_mps = False )
107131 loss .backward ()
108132
109- cpu_grad_time = time_call (cpu_grad , repeat = args .repeat )
133+ cpu_forward_time = time_call (kd_cpu_forward , repeat = args .repeat )
134+ cpu_backward_time = time_call (kd_cpu_backward , repeat = args .repeat )
135+
136+ kd_cuda_forward_time = None
137+ kd_cuda_backward_time = None
138+ pytorch3d_cuda_forward_time = None
139+ pytorch3d_cuda_backward_time = None
140+ if cuda_available and a_cuda is not None and b_cuda is not None :
141+ def kd_cuda_forward () -> None :
142+ chamfer .chamfer_distance (a_cuda , b_cuda )
143+
144+ kd_cuda_forward_time = time_call (kd_cuda_forward , sync = cuda_sync , repeat = args .repeat )
145+
146+ def kd_cuda_backward () -> None :
147+ a = a_cuda .clone ().requires_grad_ (True )
148+ b = b_cuda .clone ().requires_grad_ (True )
149+ loss = chamfer .chamfer_distance (a , b )
150+ loss .backward ()
151+
152+ kd_cuda_backward_time = time_call (kd_cuda_backward , sync = cuda_sync , repeat = args .repeat )
153+
154+ if pytorch3d_chamfer is not None :
155+ def pyt3d_cuda_forward () -> None :
156+ loss , _ = pytorch3d_chamfer (a_cuda .unsqueeze (0 ), b_cuda .unsqueeze (0 ))
157+ return loss
158+
159+ pyt3d_cuda_forward_time = time_call (pyt3d_cuda_forward , sync = cuda_sync , repeat = args .repeat )
160+
161+ def pyt3d_cuda_backward () -> None :
162+ a = a_cuda .unsqueeze (0 ).clone ().requires_grad_ (True )
163+ b = b_cuda .unsqueeze (0 ).clone ().requires_grad_ (True )
164+ loss , _ = pytorch3d_chamfer (a , b )
165+ loss .backward ()
166+
167+ pyt3d_cuda_backward_time = time_call (pyt3d_cuda_backward , sync = cuda_sync , repeat = args .repeat )
110168
111- kd_mps_time = None
112- mps_grad_time = None
169+ kd_mps_forward_time = None
170+ kd_mps_backward_time = None
113171 if mps_available and a_mps is not None and b_mps is not None :
114- kd_mps_time = time_call (
115- lambda : chamfer .closest_points (a_mps , b_mps , use_mps = True ),
116- sync = mps_sync ,
117- repeat = args .repeat ,
118- )
172+ def kd_mps_forward () -> None :
173+ chamfer .chamfer_distance (a_mps , b_mps , use_mps = True )
174+
175+ kd_mps_forward_time = time_call (kd_mps_forward , sync = mps_sync , repeat = args .repeat )
119176
120- def mps_grad () -> None :
177+ def kd_mps_backward () -> None :
121178 a = a_mps .clone ().requires_grad_ (True )
122179 b = b_mps .clone ().requires_grad_ (True )
123180 loss = chamfer .chamfer_distance (a , b , use_mps = True )
124181 loss .backward ()
125182
126- mps_grad_time = time_call (mps_grad , sync = mps_sync , repeat = args .repeat )
183+ kd_mps_backward_time = time_call (kd_mps_backward , sync = mps_sync , repeat = args .repeat )
184+
185+ pyt3d_cpu_forward_time = None
186+ pyt3d_cpu_backward_time = None
187+ if pytorch3d_chamfer is not None :
188+ def pyt3d_cpu_forward () -> None :
189+ loss , _ = pytorch3d_chamfer (a_cpu .unsqueeze (0 ), b_cpu .unsqueeze (0 ))
190+ return loss
191+
192+ pyt3d_cpu_forward_time = time_call (pyt3d_cpu_forward , repeat = args .repeat )
193+
194+ def pyt3d_cpu_backward () -> None :
195+ a = a_cpu .unsqueeze (0 ).clone ().requires_grad_ (True )
196+ b = b_cpu .unsqueeze (0 ).clone ().requires_grad_ (True )
197+ loss , _ = pytorch3d_chamfer (a , b )
198+ loss .backward ()
199+
200+ pyt3d_cpu_backward_time = time_call (pyt3d_cpu_backward , repeat = args .repeat )
127201
128202 # Prepare table rows
129203 rows = []
130204
131205 rows .append (
132206 (
133207 "Brute force" ,
134- f"{ brute_fwd_time :.3f} s" ,
135- f"{ brute_grad_time :.3f} s" ,
208+ f"{ brute_forward_time :.3f} s" ,
209+ f"{ brute_backward_time :.3f} s" ,
136210 )
137211 )
138212
139213 rows .append (
140214 (
141215 "KD-tree CPU" ,
142- f"{ cpu_kd_time :.3f} s ({ brute_fwd_time / cpu_kd_time :.2f} x)" ,
143- f"{ cpu_grad_time :.3f} s ({ brute_grad_time / cpu_grad_time :.2f} x)" ,
216+ f"{ cpu_forward_time :.3f} s ({ brute_forward_time / cpu_forward_time :.2f} x)" ,
217+ f"{ cpu_backward_time :.3f} s ({ brute_backward_time / cpu_backward_time :.2f} x)" ,
144218 )
145219 )
146220
147- if kd_mps_time is not None and mps_grad_time is not None :
221+ if kd_cuda_forward_time is not None and kd_cuda_backward_time is not None :
222+ rows .append (
223+ (
224+ "KD-tree CUDA" ,
225+ f"{ kd_cuda_forward_time :.3f} s ({ brute_forward_time / kd_cuda_forward_time :.2f} x)" ,
226+ f"{ kd_cuda_backward_time :.3f} s ({ brute_backward_time / kd_cuda_backward_time :.2f} x)" ,
227+ )
228+ )
229+ else :
230+ rows .append (("KD-tree CUDA" , "n/a" , "n/a" ))
231+
232+ if kd_mps_forward_time is not None and kd_mps_backward_time is not None :
148233 rows .append (
149234 (
150235 "KD-tree MPS" ,
151- f"{ kd_mps_time :.3f} s ({ brute_fwd_time / kd_mps_time :.2f} x)" ,
152- f"{ mps_grad_time :.3f} s ({ brute_grad_time / mps_grad_time :.2f} x)" ,
236+ f"{ kd_mps_forward_time :.3f} s ({ brute_forward_time / kd_mps_forward_time :.2f} x)" ,
237+ f"{ kd_mps_backward_time :.3f} s ({ brute_backward_time / kd_mps_backward_time :.2f} x)" ,
153238 )
154239 )
155240 else :
156241 rows .append (("KD-tree MPS" , "n/a" , "n/a" ))
157242
243+ if pytorch3d_chamfer is not None and pyt3d_cpu_forward_time is not None and pyt3d_cpu_backward_time is not None :
244+ rows .append (
245+ (
246+ "PyTorch3D CPU" ,
247+ f"{ pyt3d_cpu_forward_time :.3f} s ({ brute_forward_time / pyt3d_cpu_forward_time :.2f} x)" ,
248+ f"{ pyt3d_cpu_backward_time :.3f} s ({ brute_backward_time / pyt3d_cpu_backward_time :.2f} x)" ,
249+ )
250+ )
251+
252+ if (
253+ pyt3d_cuda_forward_time is not None
254+ and pyt3d_cuda_backward_time is not None
255+ ):
256+ rows .append (
257+ (
258+ "PyTorch3D CUDA" ,
259+ f"{ pyt3d_cuda_forward_time :.3f} s ({ brute_forward_time / pyt3d_cuda_forward_time :.2f} x)" ,
260+ f"{ pyt3d_cuda_backward_time :.3f} s ({ brute_backward_time / pyt3d_cuda_backward_time :.2f} x)" ,
261+ )
262+ )
263+ else :
264+ rows .append (("PyTorch3D CUDA" , "n/a" , "n/a" ))
265+
158266 header = ("Method" , "Forward" , "Backward" )
159267 widths = [max (len (col ), max (len (row [i ]) for row in rows )) for i , col in enumerate (header )]
160268
0 commit comments