@@ -47,6 +47,7 @@ def __init__(
47
47
restrictions ,
48
48
max_threads : int ,
49
49
block_size_names = default_block_size_names ,
50
+ defer_construction = False ,
50
51
build_neighbors_index = False ,
51
52
neighbor_method = None ,
52
53
from_cache : dict = None ,
@@ -62,6 +63,7 @@ def __init__(
62
63
Hamming: any parameter config with 1 different parameter value is a neighbor
63
64
Optionally sort the searchspace by the order in which the parameter values were specified. By default, sort goes from first to last parameter, to reverse this use sort_last_param_first.
64
65
Optionally an imported cache can be used instead with `from_cache`, in which case the `tune_params`, `restrictions` and `max_threads` arguments can be set to None, and construction is skipped.
66
+ Optionally construction can be deffered to a later time by setting `defer_construction` to True, in which case the searchspace is not built on instantiation (experimental).
65
67
"""
66
68
# check the arguments
67
69
if from_cache is not None :
@@ -76,7 +78,8 @@ def __init__(
76
78
framework_l = framework .lower ()
77
79
restrictions = restrictions if restrictions is not None else []
78
80
self .tune_params = tune_params
79
- self .tune_params_pyatf = None
81
+ self .max_threads = max_threads
82
+ self .block_size_names = block_size_names
80
83
self ._tensorspace = None
81
84
self .tensor_dtype = torch .float32 if torch_available else None
82
85
self .tensor_device = torch .device ("cpu" ) if torch_available else None
@@ -160,17 +163,19 @@ def __init__(
160
163
else :
161
164
raise ValueError (f"Solver method { solver_method } not recognized." )
162
165
163
- # build the search space
164
- self .list , self .__dict , self .size = searchspace_builder (block_size_names , max_threads , solver )
166
+ if not defer_construction :
167
+ # build the search space
168
+ self .list , self .__dict , self .size = searchspace_builder (block_size_names , max_threads , solver )
165
169
166
170
# finalize construction
167
- self .__numpy = None
168
- self .num_params = len (self .tune_params )
169
- self .indices = np .arange (self .size )
170
- if neighbor_method is not None and neighbor_method != "Hamming" :
171
- self .__prepare_neighbors_index ()
172
- if build_neighbors_index :
173
- self .neighbors_index = self .__build_neighbors_index (neighbor_method )
171
+ if not defer_construction :
172
+ self .__numpy = None
173
+ self .num_params = len (self .tune_params )
174
+ self .indices = np .arange (self .size )
175
+ if neighbor_method is not None and neighbor_method != "Hamming" :
176
+ self .__prepare_neighbors_index ()
177
+ if build_neighbors_index :
178
+ self .neighbors_index = self .__build_neighbors_index (neighbor_method )
174
179
175
180
# def __build_searchspace_ortools(self, block_size_names: list, max_threads: int) -> Tuple[List[tuple], np.ndarray, dict, int]:
176
181
# # Based on https://developers.google.com/optimization/cp/cp_solver#python_2
@@ -318,14 +323,15 @@ def all_smt(formula, keys) -> list:
318
323
319
324
return self .__parameter_space_list_to_lookup_and_return_type (parameter_space_list )
320
325
321
- def __build_searchspace_pyATF (self , block_size_names : list , max_threads : int , solver : Solver ):
322
- """Builds the searchspace using pyATF."""
323
- from pyatf import TP , Interval , Set , Tuner
324
- from pyatf .cost_functions .generic import CostFunction
325
- from pyatf .search_techniques import Exhaustive
326
+ def get_tune_params_pyatf (self , block_size_names : list = None , max_threads : int = None ):
327
+ """Convert the tune_params and restrictions to pyATF tunable parameters."""
328
+ from pyatf import TP , Interval , Set
326
329
327
- # Define a bogus cost function
328
- costfunc = CostFunction (":" ) # bash no-op
330
+ # if block_size_names or max_threads are not specified, use the defaults
331
+ if block_size_names is None :
332
+ block_size_names = self .block_size_names
333
+ if max_threads is None :
334
+ max_threads = self .max_threads
329
335
330
336
# add the Kernel Tuner default blocksize threads restrictions
331
337
assert isinstance (self .restrictions , list )
@@ -359,27 +365,36 @@ def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, so
359
365
registered_restrictions .append (index )
360
366
361
367
# define the Tunable Parameters
362
- def get_params ():
363
- params = list ()
364
- for index , (key , values ) in enumerate (self .tune_params .items ()):
365
- vi = get_interval (values )
366
- vals = (
367
- Interval (vi [0 ], vi [1 ], vi [2 ]) if vi is not None and vi [2 ] != 0 else Set (* np .array (values ).flatten ())
368
- )
369
- constraint = res_dict .get (key , None )
370
- constraint_source = None
371
- if constraint is not None :
372
- constraint , constraint_source = constraint
373
- # in case of a leftover monolithic restriction, append at the last parameter
374
- if index == len (self .tune_params ) - 1 and len (res_dict ) == 0 and len (self .restrictions ) == 1 :
375
- res , params , source = self .restrictions [0 ]
376
- assert callable (res )
377
- constraint = res
378
- params .append (TP (key , vals , constraint , constraint_source ))
379
- return params
368
+ params = list ()
369
+ for index , (key , values ) in enumerate (self .tune_params .items ()):
370
+ vi = get_interval (values )
371
+ vals = (
372
+ Interval (vi [0 ], vi [1 ], vi [2 ]) if vi is not None and vi [2 ] != 0 else Set (* np .array (values ).flatten ())
373
+ )
374
+ constraint = res_dict .get (key , None )
375
+ constraint_source = None
376
+ if constraint is not None :
377
+ constraint , constraint_source = constraint
378
+ # in case of a leftover monolithic restriction, append at the last parameter
379
+ if index == len (self .tune_params ) - 1 and len (res_dict ) == 0 and len (self .restrictions ) == 1 :
380
+ res , params , source = self .restrictions [0 ]
381
+ assert callable (res )
382
+ constraint = res
383
+ params .append (TP (key , vals , constraint , constraint_source ))
384
+ return params
385
+
386
+
387
+ def __build_searchspace_pyATF (self , block_size_names : list , max_threads : int , solver : Solver ):
388
+ """Builds the searchspace using pyATF."""
389
+ from pyatf import Tuner
390
+ from pyatf .cost_functions .generic import CostFunction
391
+ from pyatf .search_techniques import Exhaustive
392
+
393
+ # Define a bogus cost function
394
+ costfunc = CostFunction (":" ) # bash no-op
380
395
381
396
# set data
382
- self .tune_params_pyatf = get_params ( )
397
+ self .tune_params_pyatf = self . get_tune_params_pyatf ( block_size_names , max_threads )
383
398
384
399
# tune
385
400
_ , _ , tuning_data = (
0 commit comments