@@ -42,6 +42,7 @@ class EvalTask:
     impl: Any
     correctness_tests: List[Any]
     performance_tests: List[Any]
+    device: str


 @dataclass
@@ -100,8 +101,17 @@ def _worker_process(worker_id, task_queue, result_queue):
             if isinstance(impl, str):
                 impl = get_operator(impl)

+            device = torch.device(task.device)
+
+            def test_to_device_iterator(tests, device):
+                for test in tests:
+                    yield test_to_device(test, device)
+
+            correctness_tests = test_to_device_iterator(task.correctness_tests, device)
+            performance_tests = test_to_device_iterator(task.performance_tests, device)
+
             correctness_score, performance_score, test_data = eval_one_op(
-                op, impl, task.correctness_tests, task.performance_tests
+                op, impl, correctness_tests, performance_tests
             )
             result = EvalResult(
                 task_id=task.task_id,
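A minimal sketch of the laziness of the iterator introduced above, assuming test_to_device (defined further down in this file) is in scope and using types.SimpleNamespace as a hypothetical stand-in for the real test objects: each test's tensors are moved only when the consumer pulls the next test, not all up front.

    import torch
    from types import SimpleNamespace

    # assumes test_to_device from this module is importable / in scope
    def test_to_device_iterator(tests, device):
        for test in tests:
            yield test_to_device(test, device)   # transfer happens here, on demand

    cpu_tests = [SimpleNamespace(args=(torch.randn(4),), kwargs={}) for _ in range(3)]
    lazy = test_to_device_iterator(cpu_tests, torch.device("cpu"))

    first = next(lazy)            # only the first test has been processed so far
    print(first.args[0].device)   # cpu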
@@ -118,6 +128,8 @@ def _worker_process(worker_id, task_queue, result_queue):
                 )
                 logger.error(error_msg)
                 result_queue.put(ProcessDeathSignal(worker_id, error_msg))
+                torch.cuda.synchronize()
+                torch.cuda.empty_cache()
                 break
             result = EvalResult(
                 task_id=task.task_id,
@@ -158,6 +170,37 @@ def _worker_process(worker_id, task_queue, result_queue):
     logger.info(f"Worker {worker_id} exiting")


+def args_to_device(value, device):
+    if isinstance(value, torch.Tensor):
+        return value.to(device)
+    elif isinstance(value, list):
+        return [args_to_device(item, device) for item in value]
+    elif isinstance(value, tuple):
+        return tuple(args_to_device(item, device) for item in value)
+    elif isinstance(value, dict):
+        return {key: args_to_device(item, device) for key, item in value.items()}
+    else:
+        return value
+
+
+def find_device(test):
+    if isinstance(test, torch.Tensor):
+        return test.device
+    elif isinstance(test, list):
+        for item in test:
+            return find_device(item)
+    elif isinstance(test, dict):
+        for item in test.values():
+            return find_device(item)
+    return None
+
+
+def test_to_device(test, device):
+    test.args = args_to_device(test.args, device)
+    test.kwargs = args_to_device(test.kwargs, device)
+    return test
+
+
 class MultiprocessingEvaluator:
     def __init__(self, num_workers: int = 1):
         assert num_workers <= torch.cuda.device_count(), "performance will be suboptimal"
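A minimal sketch of how the three helpers added above compose, assuming they are importable from this module and using a hypothetical SimpleTest dataclass in place of the real test objects (only .args and .kwargs are required):

    import torch
    from dataclasses import dataclass, field

    @dataclass
    class SimpleTest:                 # hypothetical stand-in for the real test class
        args: tuple = ()
        kwargs: dict = field(default_factory=dict)

    t = SimpleTest(
        args=(torch.randn(4, 4), [torch.zeros(2), 3.0]),
        kwargs={"mask": torch.ones(4, dtype=torch.bool)},
    )

    # find_device reports the device of the first tensor it finds in a
    # tensor/list/dict structure (tuples are not traversed by this helper).
    print(find_device(list(t.args)))              # cpu

    # test_to_device rewrites args/kwargs in place, moving every nested tensor,
    # including ones inside tuples, lists, and dicts.
    if torch.cuda.is_available():
        t = test_to_device(t, torch.device("cuda"))
        print(t.args[0].device, t.kwargs["mask"].device)   # cuda:0 cuda:0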
@@ -183,12 +226,26 @@ def submit_task(self, op, impl, correctness_tests, performance_tests) -> int:
         if not is_pickleable(impl):
             impl = _extract_spec_name_from_op(impl)

+        orig_device = None
+        cpu_correctness_tests = []
+        for test in correctness_tests:
+            if orig_device is None:
+                orig_device = find_device(test)
+            cpu_correctness_tests.append(test_to_device(test, torch.device("cpu")))
+        if orig_device is None:
+            orig_device = torch.device("cuda")
+
+        cpu_performance_tests = []
+        for test in performance_tests:
+            cpu_performance_tests.append(test_to_device(test, torch.device("cpu")))
+
         task = EvalTask(
             task_id=task_id,
             op=op,
             impl=impl,
-            correctness_tests=list(correctness_tests),
-            performance_tests=list(performance_tests),
+            correctness_tests=cpu_correctness_tests,
+            performance_tests=cpu_performance_tests,
+            device=str(orig_device),
         )

         self.task_queue.put(task)
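A minimal end-to-end sketch of the round trip this change sets up, assuming find_device and test_to_device from this file are in scope and using types.SimpleNamespace as a hypothetical stand-in for the real test objects: the submit side records the tests' original device and moves every tensor to CPU so the pickled EvalTask crosses the process boundary without CUDA storage, and the worker side restores the recorded device before evaluation.

    import torch
    from types import SimpleNamespace

    src = "cuda" if torch.cuda.is_available() else "cpu"
    tests = [SimpleNamespace(args=(torch.randn(8, device=src),), kwargs={})]

    # Submit side: remember where the tensors lived, then move them to CPU
    # before the task is pickled onto the queue.
    orig_device = find_device(list(tests[0].args))
    if orig_device is None:
        orig_device = torch.device("cuda")
    cpu_tests = [test_to_device(t, torch.device("cpu")) for t in tests]

    # Worker side: rebuild the device from the string stored on the task and
    # move each test back before running it.
    device = torch.device(str(orig_device))
    for t in (test_to_device(ct, device) for ct in cpu_tests):
        assert t.args[0].device.type == device.type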