
Commit e8ff6e7

Improved dynamic import of modules for custom strategies
1 parent dbcb89d commit e8ff6e7

2 files changed: 20 additions & 12 deletions

kernel_tuner/file_utils.py

Lines changed: 17 additions & 10 deletions
@@ -305,16 +305,23 @@ def store_metadata_file(metadata_filename: str):
 def import_class_from_file(file_path: Path, class_name):
     """Import a class from a file."""
-    module_name = file_path.stem
-    spec = spec_from_file_location(module_name, file_path)
-    if spec is None:
-        raise ImportError(f"Could not load spec from {file_path}")
-
-    # create a module from the spec and execute it
-    module = module_from_spec(spec)
-    spec.loader.exec_module(module)
-    if not hasattr(module, class_name):
-        raise ImportError(f"Module '{module_name}' has no class '{class_name}'")
+
+    def load_module(module_name):
+        spec = spec_from_file_location(module_name, file_path)
+        if spec is None:
+            raise ImportError(f"Could not load spec from {file_path}")
+
+        # create a module from the spec and execute it
+        module = module_from_spec(spec)
+        spec.loader.exec_module(module)
+        if not hasattr(module, class_name):
+            raise ImportError(f"Module '{module_name}' has no class '{class_name}'")
+        return module
+
+    try:
+        module = load_module(file_path.stem)
+    except ImportError:
+        module = load_module(f"{file_path.parent.stem}.{file_path.stem}")
 
     # return the class from the module
     return getattr(module, class_name)
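
The reworked loader wraps the spec-based import in a local load_module helper and, if importing under the bare file stem raises ImportError, retries under a name qualified with the parent directory. A minimal usage sketch, assuming a hypothetical strategy file my_strategies/genetic.py that defines a class GeneticOptimizer:

    from pathlib import Path
    from kernel_tuner.file_utils import import_class_from_file

    # first tries to execute the file under the module name "genetic";
    # on ImportError it retries under "my_strategies.genetic"
    optimizer_class = import_class_from_file(Path("my_strategies/genetic.py"), "GeneticOptimizer")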

kernel_tuner/interface.py

Lines changed: 3 additions & 2 deletions
@@ -902,10 +902,9 @@ def tune_kernel_T1(
         class_name: str = strategy
         assert opt_path.exists(), f"Custom search method path '{opt_path}' does not exist relative to current working directory {Path.cwd()}"
         optimizer_class = import_class_from_file(opt_path, class_name)
-        budget = strategy_options.get("max_fevals", 1e12)  # if not set, use a very large number to have it run out at the time limit
         filter_keys = ["custom_search_method_path", "max_fevals", "time_limit", "constraint_aware"]
         adjusted_strategy_options = {k:v for k, v in strategy_options.items() if k not in filter_keys}
-        optimizer_instance = optimizer_class(budget=budget, **adjusted_strategy_options)
+        optimizer_instance = optimizer_class(**adjusted_strategy_options)
         strategy = OptAlgWrapper(optimizer_instance)
 
     # set the cache path
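
With the forced budget keyword gone, a custom search method only receives the strategy options that remain after filtering out custom_search_method_path, max_fevals, time_limit, and constraint_aware. A minimal sketch of a constructor compatible with this call, using a hypothetical class name (the rest of the interface expected by OptAlgWrapper is not shown):

    class MyOptimizer:
        def __init__(self, **options):
            # no mandatory `budget` parameter anymore; only the filtered
            # strategy_options are passed through as keyword arguments
            self.options = options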
@@ -973,6 +972,8 @@ def tune_kernel_T1(
         elif arg["MemoryType"] == "Scalar":
             if arg["Type"] == "float":
                 argument = numpy.float32(arg["FillValue"])
+            elif arg["Type"] == "int32":
+                argument = numpy.int32(arg["FillValue"])
             else:
                 raise NotImplementedError()
         if argument is not None:
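
The added branch converts int32 scalar arguments from a T1 problem description instead of hitting NotImplementedError. A small sketch with an illustrative argument entry:

    import numpy

    arg = {"MemoryType": "Scalar", "Type": "int32", "FillValue": 42}  # illustrative values
    argument = numpy.int32(arg["FillValue"])  # what the added branch produces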
