@@ -159,39 +159,20 @@ def get_options_from_kwargs_and_tuner_cache(name, cache_file, options_cache, *in
159
159
# TC autotuner class - ATen
160
160
###############################################################################
161
161
class TcAutotuner (object ):
162
- def __init__ (
163
- self ,
164
- tc_lang ,
165
- pop_size = 10 ,
166
- crossover_rate = 80 ,
167
- mutation_rate = 7 ,
168
- generations = 2 ,
169
- number_elites = 1 ,
170
- threads = 8 ,
171
- gpus = "0" ,
172
- proto = "/tmp/tuner.txt" ,
173
- restore_from_proto = False ,
174
- restore_number = 10 ,
175
- log_generations = False ,
176
- tuner_min_launch_total_threads = 64 ,
177
- ** kwargs
178
- ):
162
+ def __init__ (self , tc_lang , ** kwargs ):
179
163
# tuner_cache will look like:
180
164
# hash_key -> {"forward": options1, "backward": options2}
181
165
self .tuner_cache = {}
182
166
self .kwargs = kwargs
183
167
self .tc_lang = tc_lang
184
168
self .autotuner = ATenAutotuner (tc_lang )
185
- self .set_autotuner_settings (
186
- pop_size , crossover_rate , mutation_rate , generations , number_elites ,
187
- threads , gpus , proto , restore_from_proto , restore_number ,
188
- log_generations , tuner_min_launch_total_threads
189
- )
169
+ self .set_autotuner_parameters (** kwargs )
190
170
191
- def set_autotuner_settings (
192
- self , pop_size , crossover_rate , mutation_rate , generations , number_elites ,
193
- threads , gpus , proto , restore_from_proto , restore_number , log_generations ,
194
- tuner_min_launch_total_threads ,
171
+ def set_autotuner_parameters (
172
+ self , pop_size = 10 , crossover_rate = 80 , mutation_rate = 7 , generations = 2 ,
173
+ number_elites = 1 , threads = 8 , gpus = "0" , proto = "/tmp/tuner.txt" ,
174
+ restore_from_proto = False , restore_number = 10 , log_generations = False ,
175
+ tuner_min_launch_total_threads = 64 , ** kwargs
195
176
):
196
177
self .autotuner .pop_size (pop_size )
197
178
self .autotuner .crossover_rate (crossover_rate )
@@ -602,7 +583,7 @@ def autotune(self, *inputs, **kwargs):
602
583
else :
603
584
# we do the init again so that the autotuner parameters are updated
604
585
# properly if users change them
605
- self .tuner .__init__ ( self . lang , ** kwargs )
586
+ self .tuner .set_autotuner_parameters ( ** kwargs )
606
587
return self .tuner .autotune (* inputs , ** kwargs )
607
588
608
589
###############################################################################
0 commit comments