|
2 | 2 |
|
3 | 3 | import csv |
4 | 4 | import itertools |
5 | | -import sys |
6 | 5 | from typing import TypeVar, Generic, Type, overload, Union, Callable, List, Dict, Any, KeysView, Optional, OrderedDict, \ |
7 | | - cast |
| 6 | + cast, Tuple, Sequence, Protocol |
8 | 7 |
|
9 | | -if sys.version_info[1] < 8: |
10 | | - from typing_extensions import Protocol |
11 | | -else: |
12 | | - from typing import Protocol |
| 8 | +from typing_extensions import ParamSpec |
| 9 | + |
| 10 | +from ..core import Range |
13 | 11 |
|
14 | 12 |
|
15 | 13 | # from https://stackoverflow.com/questions/47965083/comparable-types-with-mypy |
@@ -152,3 +150,90 @@ def first(self, err="no elements in list") -> PartsTableRow: |
152 | 150 | if not self.rows: |
153 | 151 | raise IndexError(err) |
154 | 152 | return self.rows[0] |
| 153 | + |
| 154 | + |
| 155 | +UserFnMetaParams = ParamSpec('UserFnMetaParams') |
| 156 | +UserFnType = TypeVar('UserFnType', bound=Callable, covariant=True) |
| 157 | +class UserFnSerialiable(Protocol[UserFnMetaParams, UserFnType]): |
| 158 | + """A protocol that marks functions as usable in deserialize, that they have been registered.""" |
| 159 | + _is_serializable: None # guard attribute |
| 160 | + |
| 161 | + def __call__(self, *args: UserFnMetaParams.args, **kwargs: UserFnMetaParams.kwargs) -> UserFnType: ... |
| 162 | + __name__: str |
| 163 | + |
| 164 | + |
| 165 | +class ExperimentalUserFnPartsTable(PartsTable): |
| 166 | + """A PartsTable that can take in a user-defined function for filtering and (possibly) other operations. |
| 167 | + These functions are serialized to a string by an internal name (cannot execute arbitrary code, |
| 168 | + bounded to defined functions in the codebase), and some arguments can be serialized with the name |
| 169 | + (think partial(...)). |
| 170 | + Functions must be pre-registered using the @ExperimentalUserFnPartsTable.user_fn(...) decorator, |
| 171 | + non-pre-registered functions will not be available. |
| 172 | +
|
| 173 | + This is intended to support searches on parts tables that are cross-coupled across multiple parameters, |
| 174 | + but still restricted to within on table (e.g., no cross-optimizing RC filters). |
| 175 | +
|
| 176 | + EXPERIMENTAL - subject to change without notice.""" |
| 177 | + |
| 178 | + _FN_SERIALIZATION_SEPARATOR = ";" |
| 179 | + |
| 180 | + _user_fns: Dict[str, Tuple[Callable, Sequence[Type]]] = {} # name -> fn, [arg types] |
| 181 | + _fn_name_dict: Dict[Callable, str] = {} |
| 182 | + |
| 183 | + @staticmethod |
| 184 | + def user_fn(param_types: Sequence[Type] = []) -> Callable[[Callable[UserFnMetaParams, UserFnType]], |
| 185 | + UserFnSerialiable[UserFnMetaParams, UserFnType]]: |
| 186 | + def decorator(fn: Callable[UserFnMetaParams, UserFnType]) -> UserFnSerialiable[UserFnMetaParams, UserFnType]: |
| 187 | + """Decorator to register a user function that can be used in ExperimentalUserFnPartsTable.""" |
| 188 | + if fn.__name__ in ExperimentalUserFnPartsTable._user_fns or fn in ExperimentalUserFnPartsTable._fn_name_dict: |
| 189 | + raise ValueError(f"Function {fn.__name__} already registered.") |
| 190 | + ExperimentalUserFnPartsTable._user_fns[fn.__name__] = (fn, param_types) |
| 191 | + ExperimentalUserFnPartsTable._fn_name_dict[fn] = fn.__name__ |
| 192 | + return fn # type: ignore |
| 193 | + return decorator |
| 194 | + |
| 195 | + @classmethod |
| 196 | + def serialize_fn(cls, fn: UserFnSerialiable[UserFnMetaParams, UserFnType], |
| 197 | + *args: UserFnMetaParams.args, **kwargs: UserFnMetaParams.kwargs) -> str: |
| 198 | + """Serializes a user function to a string.""" |
| 199 | + assert not kwargs, "kwargs not supported in serialization" |
| 200 | + if fn not in cls._fn_name_dict: |
| 201 | + raise ValueError(f"Function {fn} not registered.") |
| 202 | + fn_ctor, fn_argtypes = cls._user_fns[fn.__name__] |
| 203 | + def serialize_arg(tpe: Type, val: Any) -> str: |
| 204 | + assert isinstance(val, tpe), f"in serialize {val}, expected {tpe}, got {type(val)}" |
| 205 | + if tpe is bool: |
| 206 | + return str(val) |
| 207 | + elif tpe is int: |
| 208 | + return str(val) |
| 209 | + elif tpe is float: |
| 210 | + return str(val) |
| 211 | + elif tpe is Range: |
| 212 | + return f"({val.lower},{val.upper})" |
| 213 | + else: |
| 214 | + raise TypeError(f"cannot serialize type {tpe} in user function serialization") |
| 215 | + serialized_args = [serialize_arg(tpe, arg) for tpe, arg in zip(fn_argtypes, args)] |
| 216 | + return cls._FN_SERIALIZATION_SEPARATOR.join([fn.__name__] + serialized_args) |
| 217 | + |
| 218 | + @classmethod |
| 219 | + def deserialize_fn(cls, serialized: str) -> Callable: |
| 220 | + """Deserializes a user function from a string.""" |
| 221 | + split = serialized.split(cls._FN_SERIALIZATION_SEPARATOR) |
| 222 | + if split[0] not in cls._user_fns: |
| 223 | + raise ValueError(f"Function {serialized} not registered.") |
| 224 | + fn_ctor, fn_argtypes = cls._user_fns[split[0]] |
| 225 | + assert len(split) == len(fn_argtypes) + 1 |
| 226 | + def deserialize_arg(tpe: Type, val: str) -> Any: |
| 227 | + if tpe is bool: |
| 228 | + return val == 'True' |
| 229 | + elif tpe is int: |
| 230 | + return int(val) |
| 231 | + elif tpe is float: |
| 232 | + return float(val) |
| 233 | + elif tpe is Range: |
| 234 | + parts = val[1:-1].split(",") |
| 235 | + return Range(float(parts[0]), float(parts[1])) # type: ignore |
| 236 | + else: |
| 237 | + raise TypeError(f"cannot deserialize type {tpe} in user function serialization") |
| 238 | + deserialized_args = [deserialize_arg(tpe, arg) for tpe, arg in zip(fn_argtypes, split[1:])] |
| 239 | + return fn_ctor(*deserialized_args) |
0 commit comments