Skip to content

Commit 5f3b6fc

Browse files
committed
Updated tune_kernel_T1 to be more broadly applicable
1 parent e4af9f7 commit 5f3b6fc

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

kernel_tuner/interface.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -870,29 +870,37 @@ def tune_kernel_T1(
870870
simulation_mode=False,
871871
output_T4=True,
872872
iterations=7,
873-
strategy_options=None,
874-
):
875-
"""Call the tune function with a T1 input file."""
873+
device=None,
874+
strategy: str=None,
875+
strategy_options: dict={},
876+
) -> tuple:
877+
"""
878+
Call the tune function with a T1 input file.
879+
880+
The device, strategy and strategy_options can be overridden by passing a strategy name and options, otherwise the input file specification is used.
881+
"""
876882
inputs = get_input_file(input_filepath)
877883
kernelspec: dict = inputs["KernelSpecification"]
878884
kernel_name: str = kernelspec["KernelName"]
879885
kernel_filepath = Path(kernelspec["KernelFile"])
880886
kernel_source = (
881-
kernel_filepath if kernel_filepath.exists() else Path(input_filepath).parent.parent / kernel_filepath
887+
kernel_filepath if kernel_filepath.exists() else Path(input_filepath).parent / kernel_filepath
888+
)
889+
kernel_source = (
890+
kernel_source if kernel_source.exists() else Path(input_filepath).parent.parent / kernel_filepath
882891
)
883892
assert kernel_source.exists(), f"KernelFile '{kernel_source}' does not exist at {kernel_source.resolve()}"
884893
language: str = kernelspec["Language"]
885894
problem_size = kernelspec["ProblemSize"]
886-
device = kernelspec["Device"]["Name"]
887-
strategy = inputs["Search"]["Name"]
888-
if "Attributes" in inputs["Search"]:
889-
strategy_options = {}
890-
for attribute in inputs["Search"]["Attributes"]:
891-
strategy_options[attribute["Name"]] = attribute["Value"]
895+
if device is None:
896+
device = kernelspec["Device"]["Name"]
897+
if strategy is None:
898+
strategy = inputs["Search"]["Name"]
899+
if "Attributes" in inputs["Search"]:
900+
for attribute in inputs["Search"]["Attributes"]:
901+
strategy_options[attribute["Name"]] = attribute["Value"]
892902
if "Budget" in inputs:
893903
budget = inputs["Budget"][0]
894-
if strategy_options is None:
895-
strategy_options = {}
896904
if budget["Type"] == "ConfigurationCount":
897905
strategy_options["max_fevals"] = budget["BudgetValue"]
898906
elif budget["Type"] == "TuningDuration":

0 commit comments

Comments
 (0)