|
| 1 | +from collections.abc import Callable |
| 2 | + |
1 | 3 | from posthog.hogql.ast import ArrayType, BooleanType, StringType |
2 | 4 | from posthog.hogql.base import UnknownType |
3 | 5 |
|
4 | 6 | from .core import HogQLFunctionMeta |
5 | 7 |
|
| 8 | +COMBINATORS = { |
| 9 | + "If": {"allowedSuffixes": [], "argMap": lambda min, max: [min + 1, max + 1]}, |
| 10 | + "Array": {"allowedSuffixes": ["If", "OrDefault", "OrNull"], "argMap": lambda min, max: [min, max]}, |
| 11 | + "Map": {"allowedSuffixes": ["If", "OrDefault", "OrNull"], "argMap": lambda min, max: [min, max]}, |
| 12 | + "State": {"allowedSuffixes": ["If", "OrDefault", "OrNull"], "argMap": lambda min, max: [min, max]}, |
| 13 | + "Merge": {"allowedSuffixes": ["If", "OrDefault", "OrNull"], "argMap": lambda min, max: [min, max]}, |
| 14 | + "ForEach": {"allowedSuffixes": ["If", "OrDefault", "OrNull"], "argMap": lambda min, max: [min, max]}, |
| 15 | + "OrDefault": {"allowedSuffixes": ["If"], "argMap": lambda min, max: [min, max]}, |
| 16 | + "OrNull": {"allowedSuffixes": ["If"], "argMap": lambda min, max: [min, max]}, |
| 17 | + "ArgMin": {"allowedSuffixes": ["If", "OrDefault", "OrNull"], "argMap": lambda min, max: [min + 1, max + 1]}, |
| 18 | + "ArgMax": {"allowedSuffixes": ["If", "OrDefault", "OrNull"], "argMap": lambda min, max: [min + 1, max + 1]}, |
| 19 | +} |
| 20 | + |
| 21 | +COMBINATOR_AGGREGATIONS = { |
| 22 | + "avg": HogQLFunctionMeta("avg", 1, 1, aggregate=True), |
| 23 | + "sum": HogQLFunctionMeta("sum", 1, 1, aggregate=True), |
| 24 | + "min": HogQLFunctionMeta("min", 1, 1, aggregate=True), |
| 25 | + "max": HogQLFunctionMeta("max", 1, 1, aggregate=True), |
| 26 | + "count": HogQLFunctionMeta("count", 0, 1, aggregate=True), |
| 27 | + "countDistinct": HogQLFunctionMeta("countDistinct", 1, 1, aggregate=True), |
| 28 | + "median": HogQLFunctionMeta("median", 1, 1, aggregate=True), |
| 29 | +} |
| 30 | + |
| 31 | + |
| 32 | +def _generate_suffix_combinations( |
| 33 | + base_name: str, base_meta: HogQLFunctionMeta, current_suffixes: list[str] | None = None |
| 34 | +): |
| 35 | + result = {} |
| 36 | + |
| 37 | + if current_suffixes is None: |
| 38 | + current_suffixes = [] |
| 39 | + |
| 40 | + if current_suffixes: |
| 41 | + func_name = base_name + "".join(current_suffixes) |
| 42 | + # Calculate new parameter ranges based on suffix rules |
| 43 | + min_params, max_params = base_meta.min_args, base_meta.max_args |
| 44 | + for suffix in current_suffixes: |
| 45 | + if suffix in COMBINATORS: |
| 46 | + arg_map: Callable[[int, int | None], list[int]] = COMBINATORS[suffix]["argMap"] # type: ignore |
| 47 | + min_params, max_params = arg_map(min_params, max_params) |
| 48 | + |
| 49 | + result[func_name] = HogQLFunctionMeta(func_name, min_params, max_params, aggregate=True) |
| 50 | + |
| 51 | + if not current_suffixes: |
| 52 | + available_suffixes = list(COMBINATORS.keys()) |
| 53 | + else: |
| 54 | + last_suffix = current_suffixes[-1] |
| 55 | + allowed_suffixes: list[str] = COMBINATORS.get(last_suffix, {}).get("allowedSuffixes", []) # type: ignore |
| 56 | + available_suffixes = allowed_suffixes |
| 57 | + |
| 58 | + for suffix in available_suffixes: |
| 59 | + if suffix not in current_suffixes: |
| 60 | + nested_result = _generate_suffix_combinations(base_name, base_meta, [*current_suffixes, suffix]) |
| 61 | + result.update(nested_result) |
| 62 | + |
| 63 | + return result |
| 64 | + |
| 65 | + |
| 66 | +def generate_combinator_suffix_combinations(): |
| 67 | + result = {} |
| 68 | + |
| 69 | + for base_name, base_meta in COMBINATOR_AGGREGATIONS.items(): |
| 70 | + combinations = _generate_suffix_combinations(base_name, base_meta) |
| 71 | + result.update(combinations) |
| 72 | + |
| 73 | + return result |
| 74 | + |
| 75 | + |
6 | 76 | # Permitted HogQL aggregations |
7 | 77 | # Keep in sync with the posthog.com repository: contents/docs/sql/aggregations.mdx |
8 | 78 | HOGQL_AGGREGATIONS: dict[str, HogQLFunctionMeta] = { |
| 79 | + # Generated combinator functions |
| 80 | + **generate_combinator_suffix_combinations(), |
9 | 81 | # Standard aggregate functions |
10 | 82 | "count": HogQLFunctionMeta("count", 0, 1, aggregate=True, case_sensitive=False), |
11 | 83 | "countIf": HogQLFunctionMeta("countIf", 1, 2, aggregate=True), |
|
0 commit comments