|
1 | 1 | """Module for kernel tuner utility functions.""" |
| 2 | +import ast |
2 | 3 | import errno |
3 | 4 | import json |
4 | 5 | import logging |
5 | 6 | import os |
6 | 7 | import re |
7 | 8 | import sys |
8 | 9 | import tempfile |
| 10 | +import textwrap |
9 | 11 | import time |
10 | 12 | import warnings |
11 | | -from inspect import signature |
| 13 | +from inspect import signature, getsource |
12 | 14 | from types import FunctionType |
13 | 15 | from typing import Optional, Union |
14 | 16 |
|
@@ -1037,8 +1039,97 @@ def to_equality_constraint(restriction: str, params: list[str]) -> Optional[Unio |
1037 | 1039 | return parsed_restrictions |
1038 | 1040 |
|
1039 | 1041 |
|
| 1042 | +def get_all_lambda_asts(func): |
| 1043 | + """ |
| 1044 | + Extracts the AST nodes of all lambda functions defined on the same line as func. |
| 1045 | + Args: |
| 1046 | + func: A lambda function object. |
| 1047 | + Returns: |
| 1048 | + A list of all ast.Lambda node objects on the line where func is defined. |
| 1049 | + Raises: |
| 1050 | + ValueError: If the source can't be retrieved or no lambda is found. |
| 1051 | + """ |
| 1052 | + |
| 1053 | + res = [] |
| 1054 | + try: |
| 1055 | + source = inspect.getsource(func) |
| 1056 | + source = textwrap.dedent(source).strip() |
| 1057 | + parsed = ast.parse(source) |
| 1058 | + |
| 1059 | + # Find the Lambda node |
| 1060 | + for node in ast.walk(parsed): |
| 1061 | + if isinstance(node, ast.Lambda): |
| 1062 | + res.append(node) |
| 1063 | + if not res: |
| 1064 | + raise ValueError("No lambda node found in the source.") |
| 1065 | + except OSError: |
| 1066 | + raise ValueError("Could not retrieve source. Is this defined interactively or dynamically?") |
| 1067 | + return res |
| 1068 | + |
| 1069 | + |
| 1070 | +class ConstraintLambdaTransformer(ast.NodeTransformer): |
| 1071 | + """ |
| 1072 | + Replaces any `NAME['string']` subscript with just `'string'`, if `NAME` |
| 1073 | + matches the lambda argument name. |
| 1074 | + """ |
| 1075 | + def __init__(self, dict_arg_name): |
| 1076 | + self.dict_arg_name = dict_arg_name |
| 1077 | + |
| 1078 | + def visit_Subscript(self, node): |
| 1079 | + # We only replace subscript expressions of the form <dict_arg_name>['some_string'] |
| 1080 | + if (isinstance(node.value, ast.Name) |
| 1081 | + and node.value.id == self.dict_arg_name |
| 1082 | + and isinstance(node.slice, ast.Constant) |
| 1083 | + and isinstance(node.slice.value, str)): |
| 1084 | + # Replace `dict_arg_name['some_key']` with the string used as key |
| 1085 | + return ast.Name(node.slice.value) |
| 1086 | + return self.generic_visit(node) |
| 1087 | + |
| 1088 | + |
| 1089 | +def unparse_constraint_lambda(lambda_ast): |
| 1090 | + """ |
| 1091 | + Parse the lambda function to replace accesses to tunable parameter dict |
| 1092 | + Returns string body of the rewritten lambda function |
| 1093 | + """ |
| 1094 | + args = lambda_ast.args |
| 1095 | + body = lambda_ast.args |
| 1096 | + |
| 1097 | + # Kernel Tuner only allows constraint lambdas with a single argument |
| 1098 | + arg = args.args[0].arg |
| 1099 | + |
| 1100 | + # Create transformer that replaces accesses to tunable parameter dict |
| 1101 | + # with simply the name of the tunable parameter |
| 1102 | + transformer = ConstraintLambdaTransformer(arg) |
| 1103 | + new_lambda_ast = transformer.visit(lambda_ast) |
| 1104 | + |
| 1105 | + rewritten_lambda_body_as_string = ast.unparse(new_lambda_ast.body).strip() |
| 1106 | + |
| 1107 | + return rewritten_lambda_body_as_string |
| 1108 | + |
| 1109 | + |
| 1110 | +def convert_constraint_lambdas(restrictions): |
| 1111 | + """ extract and convert all constraint lambdas from the restrictions """ |
| 1112 | + parse_callables_once = True |
| 1113 | + res = [] |
| 1114 | + for c in restrictions: |
| 1115 | + if isinstance(c, str): |
| 1116 | + res.append(c) |
| 1117 | + if callable(c) and parse_callables_once: |
| 1118 | + lambda_asts = get_all_lambda_asts(c) |
| 1119 | + |
| 1120 | + for lambda_ast in lambda_asts: |
| 1121 | + new_c = unparse_constraint_lambda(lambda_ast) |
| 1122 | + res.append(new_c) |
| 1123 | + |
| 1124 | + parse_callables_once = False |
| 1125 | + return res |
| 1126 | + |
| 1127 | + |
1040 | 1128 | def compile_restrictions(restrictions: list, tune_params: dict, monolithic = False, format = None, try_to_constraint = True) -> list[tuple[Union[str, Constraint, FunctionType], list[str]]]: |
1041 | 1129 | """Parses restrictions from a list of strings into a list of strings, Functions, or Constraints (if `try_to_constraint`) and parameters used, or a single Function if monolithic is true.""" |
| 1130 | + |
| 1131 | + restrictions = convert_constraint_lambdas(restrictions) |
| 1132 | + |
1042 | 1133 | # filter the restrictions to get only the strings |
1043 | 1134 | restrictions_str, restrictions_ignore = [], [] |
1044 | 1135 | for r in restrictions: |
|
0 commit comments