Skip to content

Commit 92e453f

Browse files
authored
Adding support for user defined functions in PyDough (#380)
Resolves #382. Adds ability to define additional function operators in the PyDough metadata, with three formats at first: - SQL Alias: calls to the function get translated directly 1:1 to a function call in the database's SQL dialect (scalar or aggregation) - SQL Window Alias: Same as SQL Alias but only for window functions - SQL Macro: calls to the function inject their arguments' SQL texts into a Python format string Also updated docstrings throughout the codebase to avoid using types of the inputs/outputs since those should be part of the type annotations (not the docstrings).
1 parent f122cdb commit 92e453f

File tree

70 files changed

+4284
-277
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+4284
-277
lines changed

documentation/metadata.md

Lines changed: 249 additions & 5 deletions
Large diffs are not rendered by default.

pydough/configs/session.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def metadata(self) -> GraphMetadata | None:
5454
Get the active metadata graph.
5555
5656
Returns:
57-
GraphMetadata: The active metadata graph.
57+
The active metadata graph.
5858
"""
5959
return self._metadata
6060

@@ -64,7 +64,7 @@ def metadata(self, graph: GraphMetadata | None) -> None:
6464
Set the active metadata graph.
6565
6666
Args:
67-
graph (GraphMetadata | None): The metadata graph to set.
67+
graph: The metadata graph to set.
6868
"""
6969
self._metadata = graph
7070

@@ -74,7 +74,7 @@ def config(self) -> PyDoughConfigs:
7474
Get the active PyDough configuration.
7575
7676
Returns:
77-
PyDoughConfigs: The active PyDough configuration.
77+
The active PyDough configuration.
7878
"""
7979
return self._config
8080

@@ -84,7 +84,7 @@ def config(self, config: PyDoughConfigs) -> None:
8484
Set the active PyDough configuration.
8585
8686
Args:
87-
config (PyDoughConfigs): The PyDough configuration to set.
87+
`config`: The PyDough configuration to set.
8888
"""
8989
self._config = config
9090

@@ -94,7 +94,7 @@ def database(self) -> DatabaseContext:
9494
Get the active database context.
9595
9696
Returns:
97-
DatabaseContext: The active database context.
97+
The active database context.
9898
"""
9999
return self._database
100100

@@ -104,7 +104,7 @@ def database(self, context: DatabaseContext) -> None:
104104
Set the active database context.
105105
106106
Args:
107-
context (DatabaseContext): The database context to set.
107+
`context`: The database context to set.
108108
"""
109109
self._database = context
110110

@@ -114,13 +114,13 @@ def connect_database(self, database_name: str, **kwargs) -> DatabaseContext:
114114
the corresponding context in case the user wants/needs to modify it.
115115
116116
Args:
117-
database_name (str): The name of the database to connect to.
117+
`database_name`: The name of the database to connect to.
118118
**kwargs: Additional keyword arguments to pass to the connection.
119119
All arguments must be accepted using the supported connect API
120120
for the dialect. Most likely the database path will be required.
121121
122122
Returns:
123-
DatabaseContext: The newly created database context.
123+
The newly created database context.
124124
"""
125125
context: DatabaseContext = load_database_context(database_name, **kwargs)
126126
self.database = context
@@ -135,13 +135,13 @@ def load_metadata_graph(self, graph_path: str, graph_name: str) -> GraphMetadata
135135
property directly later.
136136
137137
Args:
138-
graph_path (str): The path to load the graph. At this time this must be on
138+
`graph_path`: The path to load the graph. At this time this must be on
139139
the user's local file system.
140-
graph_name (str): The name under which to load the graph from the file. This
140+
`graph_name`: The name under which to load the graph from the file. This
141141
is to allow loading multiple graphs from the same json file.
142142
143143
Returns:
144-
GraphMetadata: The loaded metadata graph.
144+
The loaded metadata graph.
145145
"""
146146
graph: GraphMetadata = parse_json_metadata_from_file(graph_path, graph_name)
147147
self.metadata = graph

pydough/conversion/relational_converter.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,19 +1142,22 @@ def translate_child_pullup(self, node: HybridChildPullUp) -> TranslationOutput:
11421142
return TranslationOutput(child_result.relational_node, new_expressions)
11431143

11441144
def translate_hybridroot(self, context: TranslationOutput) -> TranslationOutput:
1145-
"""Converts a HybridRoot node into a relational tree.
1146-
This method shifts all expressions in the given context back by one level,
1147-
effectively removing the HybridRoot from the context (re-aligning them to the parent context's scope).
1148-
This is needed when stepping out of a nested context.
1149-
The HybridRoot itself does not introduce a new relational operation but serves as a logical boundary.
1150-
This method prepares the context so that subsequent operations refer to the correct expression depth.
1145+
"""
1146+
Converts a HybridRoot node into a relational tree. This method shifts
1147+
all expressions in the given context back by one level, effectively
1148+
removing the HybridRoot from the context (re-aligning them to the
1149+
parent context's scope). This is needed when stepping out of a nested
1150+
context. The HybridRoot itself does not introduce a new relational
1151+
operation but serves as a logical boundary. This method prepares the
1152+
context so that subsequent operations refer to the correct expression
1153+
depth.
11511154
11521155
Args:
1153-
context (TranslationOutput): The current translation context
1154-
associated with the HybridRoot. Must not be None.
1156+
`context`: The current translation context associated with the
1157+
HybridRoot. Must not be None.
11551158
11561159
Returns:
1157-
TranslationOutput: The translated output payload.
1160+
The translated output payload.
11581161
"""
11591162
new_expressions: dict[HybridExpr, ColumnReference] = {}
11601163
for expr, column_ref in context.expressions.items():
@@ -1324,14 +1327,12 @@ def make_relational_ordering(
13241327
Converts a list of collation expressions into a list of ExpressionSortInfo.
13251328
13261329
Args:
1327-
collation (list[CollationExpression]): The list of collation
1328-
expressions to convert.
1329-
expressions (dict[HybridExpr, ColumnReference]): The dictionary of
1330-
expressions to use for the relational ordering.
1330+
`collation`: The list of collation expressions to convert.
1331+
`expressions`: The dictionary of expressions to use for the relational
1332+
ordering.
13311333
13321334
Returns:
1333-
list[ExpressionSortInfo]: The ordering expressions converted into
1334-
ExpressionSortInfo.
1335+
The ordering expressions converted into `ExpressionSortInfo`.
13351336
"""
13361337
orderings: list[ExpressionSortInfo] = []
13371338
for col_expr in collation:

pydough/database_connectors/builtin_databases.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ def load_database_context(database_name: str, **kwargs) -> DatabaseContext:
1515
Load the database context with the appropriate connection and dialect.
1616
1717
Args:
18-
database (str): The name of the database to connect to.
19-
**kwargs: Additional keyword arguments to pass to the connection.
18+
`database`: The name of the database to connect to.
19+
`**kwargs`: Additional keyword arguments to pass to the connection.
2020
All arguments must be accepted using the supported connect API
2121
for the dialect.
2222
2323
Returns:
24-
DatabaseContext: The database context object.
24+
The database context object.
2525
"""
2626
supported_databases = {"sqlite"}
2727
connection: DatabaseConnection
@@ -44,7 +44,7 @@ def load_sqlite_connection(**kwargs) -> DatabaseConnection:
4444
around the DB 2.0 connect API.
4545
4646
Returns:
47-
DatabaseConnection: A database connection object for SQLite.
47+
A database connection object for SQLite.
4848
"""
4949
if "database" not in kwargs:
5050
raise ValueError("SQLite connection requires a database path.")

pydough/database_connectors/database_connector.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def execute_query_df(self, sql: str) -> pd.DataFrame:
3838
types are in scope and how we need to test them.
3939
4040
Args:
41-
sql (str): The SQL query to execute.
41+
`sql`: The SQL query to execute.
4242
4343
Returns:
4444
list[pt.Any]: A list of rows returned by the query.
@@ -83,10 +83,10 @@ def from_string(dialect: str) -> "DatabaseDialect":
8383
"""Convert a string to a DatabaseDialect enum.
8484
8585
Args:
86-
dialect (str): The string representation of the dialect.
86+
`dialect`: The string representation of the dialect.
8787
8888
Returns:
89-
DatabaseDialect: The dialect enum.
89+
The dialect enum.
9090
"""
9191
if dialect == "ansi":
9292
return DatabaseDialect.ANSI

pydough/evaluation/evaluate_unqualified.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,15 @@ def to_sql(node: UnqualifiedNode, **kwargs) -> str:
107107
Convert the given unqualified tree to a SQL string.
108108
109109
Args:
110-
node (UnqualifiedNode): The node to convert to SQL.
111-
**kwargs: Additional arguments to pass to the conversion for testing.
110+
`node`: The node to convert to SQL.
111+
`**kwargs`: Additional arguments to pass to the conversion for testing.
112112
From a user perspective these values should always be derived from
113113
the active session, but to allow a simple + extensible testing
114114
infrastructure in the future, any of these can be passed in using
115115
the name of the field in session.py.
116116
117117
Returns:
118-
str: The SQL string corresponding to the unqualified query.
118+
The SQL string corresponding to the unqualified query.
119119
"""
120120
graph: GraphMetadata
121121
config: PyDoughConfigs
@@ -139,15 +139,15 @@ def to_df(node: UnqualifiedNode, **kwargs) -> pd.DataFrame:
139139
DataFrame.
140140
141141
Args:
142-
node (UnqualifiedNode): The node to convert to a DataFrame.
143-
**kwargs: Additional arguments to pass to the conversion for testing.
142+
`node`: The node to convert to a DataFrame.
143+
`**kwargs`: Additional arguments to pass to the conversion for testing.
144144
From a user perspective these values should always be derived from
145145
the active session, but to allow a simple + extensible testing
146146
infrastructure in the future, any of these can be passed in using
147147
the name of the field in session.py.
148148
149149
Returns:
150-
pd.DataFrame: The DataFrame corresponding to the unqualified query.
150+
The DataFrame corresponding to the unqualified query.
151151
"""
152152
graph: GraphMetadata
153153
config: PyDoughConfigs

pydough/metadata/errors.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"PyDoughPredicate",
1818
"extract_array",
1919
"extract_bool",
20+
"extract_integer",
2021
"extract_object",
2122
"extract_string",
2223
"is_bool",
@@ -375,6 +376,31 @@ def extract_bool(json_obj: dict, key_name: str, obj_name: str) -> bool:
375376
return value
376377

377378

379+
def extract_integer(json_obj: dict, key_name: str, obj_name: str) -> int:
380+
"""
381+
Extracts an integer field from a JSON object, returning the integer field
382+
and verifying that the field exists and is well formed.
383+
384+
Args:
385+
`json_obj`: the JSON object to extract the string from.
386+
`key_name`: the name of the key in the JSON object that
387+
contains the string.
388+
`obj_name`: the name of the object being extracted from, to be used
389+
in error messages.
390+
391+
Returns:
392+
The integer value of the field.
393+
394+
Raises:
395+
`PyDoughMetadataException` if the JSON object does not contain a key
396+
with the name `key_name`, or if the value of the key is not an integer.
397+
"""
398+
HasPropertyWith(key_name, is_integer).verify(json_obj, obj_name)
399+
value = json_obj[key_name]
400+
assert isinstance(value, int)
401+
return value
402+
403+
378404
def extract_array(json_obj: dict, key_name: str, obj_name: str) -> list:
379405
"""
380406
Extracts an array field from a JSON object, returning the string field

pydough/metadata/graphs/graph_metadata.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,16 @@
22
Definition of PyDough metadata for a graph.
33
"""
44

5+
from typing import TYPE_CHECKING
6+
57
from pydough.metadata.abstract_metadata import AbstractMetadata
68
from pydough.metadata.errors import HasType, PyDoughMetadataException, is_valid_name
79

10+
if TYPE_CHECKING:
11+
from pydough.pydough_operators import (
12+
ExpressionFunctionOperator,
13+
)
14+
815

916
class GraphMetadata(AbstractMetadata):
1017
"""
@@ -17,6 +24,7 @@ class GraphMetadata(AbstractMetadata):
1724
"version",
1825
"collections",
1926
"relationships",
27+
"functions",
2028
"additional definitions",
2129
"verified pydough analysis",
2230
"extra semantic info",
@@ -39,6 +47,7 @@ def __init__(
3947
self._verified_pydough_analysis: list[dict] | None = verified_pydough_analysis
4048
self._name: str = name
4149
self._collections: dict[str, AbstractMetadata] = {}
50+
self._functions: dict[str, ExpressionFunctionOperator] = {}
4251
super().__init__(description, synonyms, extra_semantic_info)
4352

4453
@property
@@ -55,6 +64,13 @@ def collections(self) -> dict[str, AbstractMetadata]:
5564
"""
5665
return self._collections
5766

67+
@property
68+
def functions(self) -> dict[str, "ExpressionFunctionOperator"]:
69+
"""
70+
The user defined functions contained within the graph.
71+
"""
72+
return self._functions
73+
5874
@property
5975
def error_name(self) -> str:
6076
return f"graph {self.name!r}"
@@ -131,3 +147,46 @@ def get_collection(self, collection_name: str) -> AbstractMetadata:
131147

132148
def __getitem__(self, key: str):
133149
return self.get_collection(key)
150+
151+
def get_function_names(self) -> list[str]:
152+
"""
153+
Fetches all of the names of user defined functions in the graph.
154+
"""
155+
return list(self.functions)
156+
157+
def get_function(self, function_name: str) -> "ExpressionFunctionOperator":
158+
"""
159+
Fetches a specific function's metadata from within the graph by name.
160+
"""
161+
if function_name not in self.functions:
162+
raise PyDoughMetadataException(
163+
f"{self.error_name} does not have a function named {function_name!r}"
164+
)
165+
return self.functions[function_name]
166+
167+
def add_function(self, name: str, function: "ExpressionFunctionOperator") -> None:
168+
"""
169+
Adds a new user defined function to the graph.
170+
171+
Args:
172+
`name`: the name of the function.
173+
`function`: the function operator being inserted into the graph.
174+
175+
Raises:
176+
`PyDoughMetadataException`: if `function` cannot be inserted
177+
into the graph because of a name collision.
178+
"""
179+
is_valid_name.verify(name, "function name")
180+
if name == self.name:
181+
raise PyDoughMetadataException(
182+
f"Function name {name!r} cannot be the same as the graph name {self.name!r}"
183+
)
184+
if name in self.get_collection_names():
185+
raise PyDoughMetadataException(
186+
f"Function name {name!r} cannot be the same as a collection name in {self.error_name}"
187+
)
188+
if name in self.functions:
189+
raise PyDoughMetadataException(
190+
f"Function {name!r} already exists in {self.error_name}"
191+
)
192+
self.functions[name] = function

0 commit comments

Comments
 (0)