Skip to content

Commit 13f61ec

Browse files
add support for lambdas
1 parent af88572 commit 13f61ec

File tree

1 file changed

+92
-1
lines changed

1 file changed

+92
-1
lines changed

kernel_tuner/util.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
"""Module for kernel tuner utility functions."""
2+
import ast
23
import errno
34
import json
45
import logging
56
import os
67
import re
78
import sys
89
import tempfile
10+
import textwrap
911
import time
1012
import warnings
11-
from inspect import signature
13+
from inspect import signature, getsource
1214
from types import FunctionType
1315
from typing import Optional, Union
1416

@@ -1037,8 +1039,97 @@ def to_equality_constraint(restriction: str, params: list[str]) -> Optional[Unio
10371039
return parsed_restrictions
10381040

10391041

1042+
def get_all_lambda_asts(func):
1043+
"""
1044+
Extracts the AST nodes of all lambda functions defined on the same line as func.
1045+
Args:
1046+
func: A lambda function object.
1047+
Returns:
1048+
A list of all ast.Lambda node objects on the line where func is defined.
1049+
Raises:
1050+
ValueError: If the source can't be retrieved or no lambda is found.
1051+
"""
1052+
1053+
res = []
1054+
try:
1055+
source = inspect.getsource(func)
1056+
source = textwrap.dedent(source).strip()
1057+
parsed = ast.parse(source)
1058+
1059+
# Find the Lambda node
1060+
for node in ast.walk(parsed):
1061+
if isinstance(node, ast.Lambda):
1062+
res.append(node)
1063+
if not res:
1064+
raise ValueError("No lambda node found in the source.")
1065+
except OSError:
1066+
raise ValueError("Could not retrieve source. Is this defined interactively or dynamically?")
1067+
return res
1068+
1069+
1070+
class ConstraintLambdaTransformer(ast.NodeTransformer):
1071+
"""
1072+
Replaces any `NAME['string']` subscript with just `'string'`, if `NAME`
1073+
matches the lambda argument name.
1074+
"""
1075+
def __init__(self, dict_arg_name):
1076+
self.dict_arg_name = dict_arg_name
1077+
1078+
def visit_Subscript(self, node):
1079+
# We only replace subscript expressions of the form <dict_arg_name>['some_string']
1080+
if (isinstance(node.value, ast.Name)
1081+
and node.value.id == self.dict_arg_name
1082+
and isinstance(node.slice, ast.Constant)
1083+
and isinstance(node.slice.value, str)):
1084+
# Replace `dict_arg_name['some_key']` with the string used as key
1085+
return ast.Name(node.slice.value)
1086+
return self.generic_visit(node)
1087+
1088+
1089+
def unparse_constraint_lambda(lambda_ast):
1090+
"""
1091+
Parse the lambda function to replace accesses to tunable parameter dict
1092+
Returns string body of the rewritten lambda function
1093+
"""
1094+
args = lambda_ast.args
1095+
body = lambda_ast.args
1096+
1097+
# Kernel Tuner only allows constraint lambdas with a single argument
1098+
arg = args.args[0].arg
1099+
1100+
# Create transformer that replaces accesses to tunable parameter dict
1101+
# with simply the name of the tunable parameter
1102+
transformer = ConstraintLambdaTransformer(arg)
1103+
new_lambda_ast = transformer.visit(lambda_ast)
1104+
1105+
rewritten_lambda_body_as_string = ast.unparse(new_lambda_ast.body).strip()
1106+
1107+
return rewritten_lambda_body_as_string
1108+
1109+
1110+
def convert_constraint_lambdas(restrictions):
1111+
""" extract and convert all constraint lambdas from the restrictions """
1112+
parse_callables_once = True
1113+
res = []
1114+
for c in restrictions:
1115+
if isinstance(c, str):
1116+
res.append(c)
1117+
if callable(c) and parse_callables_once:
1118+
lambda_asts = get_all_lambda_asts(c)
1119+
1120+
for lambda_ast in lambda_asts:
1121+
new_c = unparse_constraint_lambda(lambda_ast)
1122+
res.append(new_c)
1123+
1124+
parse_callables_once = False
1125+
return res
1126+
1127+
10401128
def compile_restrictions(restrictions: list, tune_params: dict, monolithic = False, format = None, try_to_constraint = True) -> list[tuple[Union[str, Constraint, FunctionType], list[str]]]:
10411129
"""Parses restrictions from a list of strings into a list of strings, Functions, or Constraints (if `try_to_constraint`) and parameters used, or a single Function if monolithic is true."""
1130+
1131+
restrictions = convert_constraint_lambdas(restrictions)
1132+
10421133
# filter the restrictions to get only the strings
10431134
restrictions_str, restrictions_ignore = [], []
10441135
for r in restrictions:

0 commit comments

Comments
 (0)