@@ -1,3 +1,4 @@
+import sys
 import time
 import torch
 import inspect
@@ -38,6 +39,12 @@ def medium_transpose():
     return (rand(32, 12, 64, 64).transpose(-1, -2),
             rand(32, 12, 64, 64).transpose(-1, -2))

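+# Contiguous variants: a 4-D image-batch shape and a smaller 3-D shape.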
+def medium2():
+    return (rand(32, 3, 224, 224), rand(32, 3, 224, 224))
+
+def medium3d():
+    return (rand(16, 32, 64), rand(16, 32, 64))
+
 def medium_channels_last():
     return (rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
             rand(32, 3, 224, 224).to(memory_format=torch.channels_last))
@@ -56,6 +63,10 @@ def large_transpose():
     return (rand(8192, 8192).transpose(0, 1),
             rand(8192, 8192).transpose(0, 1))

+def large_channels_last():
+    return (rand(32, 32, 256, 256).to(memory_format=torch.channels_last),
+            rand(32, 32, 256, 256).to(memory_format=torch.channels_last))
+
 def pathological_broadcast():
     return (rand(1, 32, 32, 2), rand(1024, 1, 1, 2))

@@ -89,14 +100,14 @@ def log(a):
 def exp(a):
     return a.exp()

-def pow(a):
+def square(a):
     return a ** 2

 def fma(a, b):
     return a * b + b

 def hardswish(a):
-    return a * (a + 3).clamp(0, 6) / 6
+    return a * (a + 3.0).clamp(0.0, 6.0) / 6.0

 def native_hardswish(a):
     return torch._C._nn.hardswish(a)
@@ -107,19 +118,55 @@ def softplus(a):
 def mish(a):
     return a * ((a * 1.0).exp().log1p() / 1.0).tanh()

+# ------------------------------------------------------------------------------
+# Helpers
+# ------------------------------------------------------------------------------
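+# Time iters calls to fn(*args) with a wall-clock timer; sufficient on CPU,
+# where ops execute synchronously.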
+def time_cpu(fn, args, iters):
+    s = time.perf_counter()
+    for _ in range(iters):
+        fn(*args)
+    e = time.perf_counter()
+    return e - s
+
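+# CUDA kernels launch asynchronously, so bracket the loop with CUDA events and
+# synchronize before reading the elapsed time (reported in ms, hence / 1e3).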
+def time_cuda(fn, args, iters):
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    start.record()
+    for _ in range(iters):
+        fn(*args)
+    end.record()
+    torch.cuda.synchronize()
+    return start.elapsed_time(end) / 1e3
+
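+# Warm up for 3 iterations, time a single call to pick an iteration count that
+# makes the measured run last about one second, then report seconds per call.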
+def benchmark_with_timer(fn, args, timer):
+    timer(fn, args, 3)
+    calibration = timer(fn, args, 1)
+    iters = int(1.0 / calibration)
+    return timer(fn, args, iters) / iters
+
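+# Choose the timer based on the device of the first argument.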
+def benchmark(fn, args):
+    timer = time_cpu if args[0].device.type == "cpu" else time_cuda
+    return benchmark_with_timer(fn, args, timer)
+
+def micros(s):
+    return f"{s * 1e6:.1f}"
+
 shapes = [
     scalar,
     small,
     small_2d,
     small_broadcast,
     medium,
+    medium2,
+    medium3d,
     medium_sliced,
     medium_transpose,
     medium_channels_last,
     medium_broadcast,
     medium_broadcast_channels_last,
     large,
     large_transpose,
+    large_channels_last,
     pathological_broadcast,
 ]

@@ -133,20 +180,16 @@ def mish(a):
     tanh,
     log,
     exp,
-    pow,
+    square,
     fma,
     hardswish,
     native_hardswish,
 ]
-#shapes = [large_transpose]
-#operators = [add]
-#shapes = [scalar]
-#operators = [add]
+
 nope = set()
 for shape, operator in itertools.product(shapes, operators):
     nargs = len(inspect.signature(operator).parameters)
     args = shape()[:nargs]
-    #print(f"{operator.__name__} {shape.__name__}")

     try:
         if shape == medium_transpose:
@@ -160,41 +203,13 @@ def mish(a):
     ts_op = torch.jit.script(operator)
     torch.testing.assert_allclose(operator(*args), ts_op(*args))

-def time_cpu(fn, args, iters):
-    s = time.perf_counter()
-    for _ in range(iters):
-        fn(*args)
-    e = time.perf_counter()
-    return e - s
-
-def time_cuda(fn, args, iters):
-    start = torch.cuda.Event(enable_timing=True)
-    end = torch.cuda.Event(enable_timing=True)
-    start.record()
-    for _ in range(iters):
-        fn(*args)
-    end.record()
-    torch.cuda.synchronize()
-    return start.elapsed_time(end) / 1e3
-
-def benchmark_with_timer(fn, args, timer):
-    timer(fn, args, 3)
-    calibration = timer(fn, args, 1)
-    iters = int(1.0 / calibration)
-    return timer(fn, args, iters) / iters
-
-def benchmark(fn, args):
-    timer = time_cpu if args[0].device.type == "cpu" else time_cuda
-    return benchmark_with_timer(fn, args, timer)
-
-def micros(s):
-    return f"{s * 1e6:.1f}"

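+# CSV header: one row per (fuser, device, operator, shape), time in microseconds.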
+print("fuser,device,operator,shape,time")
 results = []
 for shape, operator in itertools.product(shapes, operators):
     nargs = len(inspect.signature(operator).parameters)
     args = shape()[:nargs]
-
+
     result = benchmark(operator, args)
     print(",".join(["eager", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
     try:
@@ -206,18 +221,9 @@ def micros(s):
         result = benchmark(pw_op, args)
         print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
     except Exception:
-        #print(f"pointwise_operator failed on {operator.__name__}, {shape.__name__}")
-        #nope.add((operator, shape))
         print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(float("nan"))]))

     ts_op = torch.jit.script(operator)
     result = benchmark(ts_op, args)
     print(",".join(["fuser", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
-
-    # cpu
-    # parallel cpu
-    # cuda
-
-    # casts
-
-    # inplace?
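+    # Flush after each configuration so rows stream out promptly when piped.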
+    sys.stdout.flush()