1
- import time
2
1
import argparse
2
+ import time
3
+
3
4
import torch
4
5
from torch .profiler import profile , ProfilerActivity
5
6
@@ -13,36 +14,43 @@ def Index_fill(input, indices, dim, device):
13
14
else :
14
15
output = input .index_fill (dim , indices , 2 )
15
16
17
+
16
18
def run_profile (input , indices , dim , cache_r , cache_w , device , num_iter ):
17
19
with profile (
18
- activities = [ProfilerActivity .CPU ,
19
- ProfilerActivity .XPU if device == 'xpu' else ProfilerActivity .CUDA ],
20
+ activities = [
21
+ ProfilerActivity .CPU ,
22
+ ProfilerActivity .XPU if device == "xpu" else ProfilerActivity .CUDA ,
23
+ ],
20
24
record_shapes = True ,
21
25
) as prof :
22
26
for i in range (num_iter ):
23
27
cache_r = cache_w * i
24
28
Index_fill (input , indices , dim , device )
25
- print (prof .key_averages ().table (sort_by = "{}_time_total" .format (device )))
29
+ print (prof .key_averages ().table (sort_by = f"{ device } _time_total" ))
30
+
26
31
27
32
def run_e2e (input , indices , dim , cache_r , cache_w , device , num_iter ):
28
- if device in [' xpu' , ' cuda' ]:
29
- torch .xpu .synchronize () if device == ' xpu' else torch .cuda .synchronize ()
33
+ if device in [" xpu" , " cuda" ]:
34
+ torch .xpu .synchronize () if device == " xpu" else torch .cuda .synchronize ()
30
35
t1 = time .time ()
31
36
for i in range (num_iter ):
32
37
cache_r = cache_w * i
33
38
Index_fill (input , indices , dim , device )
34
- if device in [' xpu' , ' cuda' ]:
35
- torch .xpu .synchronize () if device == ' xpu' else torch .cuda .synchronize ()
39
+ if device in [" xpu" , " cuda" ]:
40
+ torch .xpu .synchronize () if device == " xpu" else torch .cuda .synchronize ()
36
41
t2 = time .time ()
37
42
e2e_time = (t2 - t1 ) / num_iter
38
43
print ("E2E total time:" , f"{ float (e2e_time ):.20f} " )
39
44
45
+
40
46
def benchmark (args ):
41
47
for shape in shape_list :
42
48
for dtype in [torch .bfloat16 , torch .float16 , torch .float32 ]:
43
49
for dim in [0 , 1 ]:
44
50
input = torch .zeros (shape , dtype = dtype , device = args .device )
45
- indices = torch .linspace (0 , 1022 , steps = 512 , device = args .device ).to (torch .long )
51
+ indices = torch .linspace (0 , 1022 , steps = 512 , device = args .device ).to (
52
+ torch .long
53
+ )
46
54
y_0 = torch .ones ((512 , 1024 ), dtype = dtype , device = args .device )
47
55
y_1 = torch .randn ((1024 , 512 ), dtype = dtype , device = args .device )
48
56
cache_r = torch .randn ((1024 * 1024 * 1024 ), device = args .device )
@@ -62,24 +70,45 @@ def benchmark(args):
62
70
backward ,
63
71
)
64
72
if not args .e2e_only :
65
- run_profile (input , indices , dim , cache_r , cache_w , args .device , args .num_iter )
73
+ run_profile (
74
+ input ,
75
+ indices ,
76
+ dim ,
77
+ cache_r ,
78
+ cache_w ,
79
+ args .device ,
80
+ args .num_iter ,
81
+ )
66
82
67
83
if not args .profile_only :
68
- run_e2e (input , indices , dim , cache_r , cache_w , args .device , args .num_iter )
84
+ run_e2e (
85
+ input ,
86
+ indices ,
87
+ dim ,
88
+ cache_r ,
89
+ cache_w ,
90
+ args .device ,
91
+ args .num_iter ,
92
+ )
93
+
69
94
70
95
def parse_args ():
71
- parser = argparse .ArgumentParser (description = 'OP Benchmark' )
72
- parser .add_argument ('--device' , type = str , default = 'xpu' ,
73
- help = 'Device to run on (e.g., "cpu", "cuda", "xpu")' )
96
+ parser = argparse .ArgumentParser (description = "OP Benchmark" )
97
+ parser .add_argument (
98
+ "--device" ,
99
+ type = str ,
100
+ default = "xpu" ,
101
+ help = 'Device to run on (e.g., "cpu", "cuda", "xpu")' ,
102
+ )
74
103
group = parser .add_mutually_exclusive_group ()
75
- group .add_argument ('--profile-only' , action = 'store_true' ,
76
- help = 'Only Run profile timing' )
77
- group .add_argument ('--e2e-only' , action = 'store_true' ,
78
- help = 'Only Run E2E timing' )
79
- parser .add_argument ('--num-iter' , type = int , default = 20 ,
80
- help = 'Number of iterations' )
104
+ group .add_argument (
105
+ "--profile-only" , action = "store_true" , help = "Only Run profile timing"
106
+ )
107
+ group .add_argument ("--e2e-only" , action = "store_true" , help = "Only Run E2E timing" )
108
+ parser .add_argument ("--num-iter" , type = int , default = 20 , help = "Number of iterations" )
81
109
return parser .parse_args ()
82
110
111
+
83
112
if __name__ == "__main__" :
84
113
args = parse_args ()
85
114
benchmark (args )
0 commit comments