@@ -42,11 +42,11 @@ def nop_args(
4242def do_bench_walltime (fn ):
4343 print ("Compiling..." )
4444 fn ()
45- torch .cuda .synchronize ()
45+ torch .xpu .synchronize ()
4646
4747 for _ in range (1000 ):
4848 fn ()
49- torch .cuda .synchronize ()
49+ torch .xpu .synchronize ()
5050
5151 n_repeat = 10000
5252
@@ -55,11 +55,11 @@ def do_bench_walltime(fn):
5555 for _ in range (25 ):
5656 print ("Running %d benchmarking iterations..." % n_repeat )
5757 # Benchmark
58- torch .cuda .synchronize ()
58+ torch .xpu .synchronize ()
5959 start_time = time .time ()
6060 for _ in range (n_repeat ):
6161 fn ()
62- torch .cuda .synchronize ()
62+ torch .xpu .synchronize ()
6363 end_time = time .time ()
6464 wall_time_ms = (end_time - start_time ) * 1e3 / n_repeat
6565 mses .append (wall_time_ms )
@@ -71,7 +71,7 @@ def do_bench_walltime(fn):
7171 profile .enable ()
7272 for _ in range (n_repeat ):
7373 fn ()
74- torch .cuda .synchronize ()
74+ torch .xpu .synchronize ()
7575 profile .disable ()
7676 stats = pstats .Stats (profile )
7777 stats .sort_stats ("time" )
@@ -81,9 +81,9 @@ def do_bench_walltime(fn):
8181
8282def main (use_tensor_desc : bool ):
8383 if use_tensor_desc :
84- targs = [TensorDescriptor .from_tensor (torch .zeros (1 , 16 , device = "cuda " ), block_shape = [1 , 16 ]) for _ in range (5 )]
84+ targs = [TensorDescriptor .from_tensor (torch .zeros (1 , 16 , device = "xpu " ), block_shape = [1 , 16 ]) for _ in range (5 )]
8585 else :
86- targs = [torch .zeros (1 , device = "cuda " ) for _ in range (5 )]
86+ targs = [torch .zeros (1 , device = "xpu " ) for _ in range (5 )]
8787 ncargs = [0 , 1 , 1024 , 2 ** 31 - 1 , 2 ** 64 - 1 , False , True , None , (16 , 16 )]
8888 cargs = [32 , False , True , 0 , 64 ]
8989
0 commit comments