From 34eb99fc78eccf877fdc89a61189f6ba1a82dfb4 Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 15 Sep 2025 12:54:46 -0400 Subject: [PATCH 1/3] Adding sesssion kwarg option to APIs, and replacing daisy-chaining of graph/config with the session in most places --- documentation/usage.md | 12 ++++ pydough/conversion/agg_split.py | 22 +++--- pydough/conversion/filter_pushdown.py | 10 +-- pydough/conversion/hybrid_translator.py | 12 ++-- pydough/conversion/relational_converter.py | 26 ++++--- .../conversion/relational_simplification.py | 36 +++++----- pydough/evaluation/evaluate_unqualified.py | 68 +++++++++++++------ pydough/exploration/explain.py | 24 ++++--- pydough/exploration/structure.py | 10 +-- pydough/exploration/term.py | 18 +++-- pydough/sqlglot/execute_relational.py | 28 ++++---- .../sqlglot_relational_expression_visitor.py | 12 ++-- pydough/sqlglot/sqlglot_relational_visitor.py | 16 ++--- pydough/unqualified/qualification.py | 38 +++++------ tests/README.md | 2 +- tests/conftest.py | 39 +++++++++-- tests/test_exploration.py | 11 ++- tests/test_logging.py | 15 +--- tests/test_metadata_errors.py | 7 +- tests/test_qdag_conversion.py | 14 ++-- tests/test_qualification.py | 28 ++++---- tests/test_qualification_errors.py | 16 ++--- tests/test_relational_execution.py | 10 +-- tests/test_relational_execution_tpch.py | 7 +- tests/test_relational_nodes_to_sqlglot.py | 9 +-- tests/test_relational_to_sql.py | 14 ++-- tests/testing_utilities.py | 11 +-- 27 files changed, 280 insertions(+), 235 deletions(-) diff --git a/documentation/usage.md b/documentation/usage.md index 00d231438..4387fb79b 100644 --- a/documentation/usage.md +++ b/documentation/usage.md @@ -425,6 +425,7 @@ The `to_sql` API takes in PyDough code and transforms it into SQL query text wit - `metadata`: the PyDough knowledge graph to use for the conversion (if omitted, `pydough.active_session.metadata` is used instead). - `config`: the PyDough configuration settings to use for the conversion (if omitted, `pydough.active_session.config` is used instead). - `database`: the database context to use for the conversion (if omitted, `pydough.active_session.database` is used instead). The database context matters because it controls which SQL dialect is used for the translation. +- `session`: a PyDough session object which, if provided, is used instead of `pydough.active_session` or the `metadata` / `config` / `database` arguments. Note: this argument cannot be used alongside those arguments. Below is an example of using `pydough.to_sql` and the output (the SQL output may be outdated if PyDough's SQL conversion process has been updated): @@ -435,6 +436,16 @@ result = european_countries.CALCULATE(name, n_custs=COUNT(customers)) pydough.to_sql(result, columns=["name", "n_custs"]) ``` +""" +s = ''' +european_countries = nations.WHERE(region.name == "EUROPE") +result = european_countries.CALCULATE(name, n_custs=COUNT(customers)) +''' +pydough.active_session.load_metadata_graph("tests/test_metadata/sample_graphs.json", "TPCH") +u = pydough.from_string(s) +print(pydough.to_sql(result, columns=["name", "n_custs"])) +""" + ```sql SELECT name, COALESCE(agg_0, 0) AS n_custs FROM ( @@ -478,6 +489,7 @@ The `to_df` API does all the same steps as the [`to_sql` API](#pydoughto_sql), b - `metadata`: the PyDough knowledge graph to use for the conversion (if omitted, `pydough.active_session.metadata` is used instead). 
- `config`: the PyDough configuration settings to use for the conversion (if omitted, `pydough.active_session.config` is used instead). - `database`: the database context to use for the conversion (if omitted, `pydough.active_session.database` is used instead). The database context matters because it controls which SQL dialect is used for the translation. +- `session`: a PyDough session object which, if provided, is used instead of `pydough.active_session` or the `metadata` / `config` / `database` arguments. Note: this argument cannot be used alongside those arguments. - `display_sql`: displays the sql before executing in a logger. Below is an example of using `pydough.to_df` and the output, attached to a sqlite database containing data for the TPC-H schema: diff --git a/pydough/conversion/agg_split.py b/pydough/conversion/agg_split.py index dad089a7c..57f715e81 100644 --- a/pydough/conversion/agg_split.py +++ b/pydough/conversion/agg_split.py @@ -7,7 +7,7 @@ import pydough.pydough_operators as pydop -from pydough.configs import PyDoughConfigs +from pydough.configs import PyDoughSession from pydough.relational import ( Aggregate, CallExpression, @@ -51,7 +51,7 @@ """ -def decompose_aggregations(node: Aggregate, config: PyDoughConfigs) -> RelationalNode: +def decompose_aggregations(node: Aggregate, session: PyDoughSession) -> RelationalNode: """ Splits up an aggregate node into an aggregate followed by a projection when the aggregate contains 1+ calls to functions that can be split into 1+ @@ -59,7 +59,7 @@ def decompose_aggregations(node: Aggregate, config: PyDoughConfigs) -> Relationa Args: `node`: the aggregate node to be decomposed. - `config`: the current configuration settings. + `session`: the PyDough session used during the transformation. Returns: The projection node on top of the new aggregate, overall containing the @@ -110,7 +110,7 @@ def decompose_aggregations(node: Aggregate, config: PyDoughConfigs) -> Relationa ) # If the config specifies that the default value for AVG should be # zero, wrap the division in a DEFAULT_TO call. - if config.avg_default_zero: + if session.config.avg_default_zero: avg_call = CallExpression( pydop.DEFAULT_TO, agg.data_type, @@ -277,7 +277,7 @@ def transpose_aggregate_join( def attempt_join_aggregate_transpose( - node: Aggregate, join: Join, config: PyDoughConfigs + node: Aggregate, join: Join, session: PyDoughSession ) -> tuple[RelationalNode, bool]: """ Determine whether the aggregate join transpose operation can occur, and if @@ -396,7 +396,7 @@ def attempt_join_aggregate_transpose( for col in node.aggregations.values(): if col.op in decomposable_aggfuncs: return split_partial_aggregates( - decompose_aggregations(node, config), config + decompose_aggregations(node, session), session ), False # Keep a dictionary for the projection columns that will be used to post-process @@ -464,7 +464,7 @@ def attempt_join_aggregate_transpose( def split_partial_aggregates( - node: RelationalNode, config: PyDoughConfigs + node: RelationalNode, session: PyDoughSession ) -> RelationalNode: """ Splits partial aggregates above joins into two aggregates, one above the @@ -473,7 +473,7 @@ def split_partial_aggregates( Args: `node`: the root node of the relational plan to be transformed. - `config`: the current configuration settings. + `session`: the PyDough session used during the transformation. Returns: The transformed node. The transformation is also done-in-place. 
@@ -481,11 +481,13 @@ def split_partial_aggregates( # If the aggregate+join pattern is detected, attempt to do the transpose. handle_inputs: bool = True if isinstance(node, Aggregate) and isinstance(node.input, Join): - node, handle_inputs = attempt_join_aggregate_transpose(node, node.input, config) + node, handle_inputs = attempt_join_aggregate_transpose( + node, node.input, session + ) # If needed, recursively invoke the procedure on all inputs to the node. if handle_inputs: node = node.copy( - inputs=[split_partial_aggregates(input, config) for input in node.inputs] + inputs=[split_partial_aggregates(input, session) for input in node.inputs] ) return node diff --git a/pydough/conversion/filter_pushdown.py b/pydough/conversion/filter_pushdown.py index 868fa9d59..7e00f7aa8 100644 --- a/pydough/conversion/filter_pushdown.py +++ b/pydough/conversion/filter_pushdown.py @@ -6,7 +6,7 @@ import pydough.pydough_operators as pydop -from pydough.configs import PyDoughConfigs +from pydough.configs import PyDoughSession from pydough.relational import ( Aggregate, CallExpression, @@ -66,7 +66,7 @@ class FilterPushdownShuttle(RelationalShuttle): cannot be pushed further. """ - def __init__(self, configs: PyDoughConfigs): + def __init__(self, session: PyDoughSession): # The set of filters that are currently being pushed down. When # visit_xxx is called, it is presumed that the set of conditions in # self.filters are the conditions that can be pushed down as far as the @@ -76,7 +76,7 @@ def __init__(self, configs: PyDoughConfigs): # simplification logic to aid in advanced filter predicate inference, # such as determining that a left join is redundant because if the RHS # column is null then the filter will always be false. - self.simplifier: SimplificationShuttle = SimplificationShuttle(configs) + self.simplifier: SimplificationShuttle = SimplificationShuttle(session) def reset(self): self.filters = set() @@ -300,7 +300,7 @@ def visit_empty_singleton(self, empty_singleton: EmptySingleton) -> RelationalNo return self.flush_remaining_filters(empty_singleton, self.filters, set()) -def push_filters(node: RelationalNode, configs: PyDoughConfigs) -> RelationalNode: +def push_filters(node: RelationalNode, session: PyDoughSession) -> RelationalNode: """ Transpose filter conditions down as far as possible. @@ -314,5 +314,5 @@ def push_filters(node: RelationalNode, configs: PyDoughConfigs) -> RelationalNod the node or into one of its inputs, or possibly both if there are multiple filters. """ - pusher: FilterPushdownShuttle = FilterPushdownShuttle(configs) + pusher: FilterPushdownShuttle = FilterPushdownShuttle(session) return node.accept_shuttle(pusher) diff --git a/pydough/conversion/hybrid_translator.py b/pydough/conversion/hybrid_translator.py index 3f4df1066..2d6670a97 100644 --- a/pydough/conversion/hybrid_translator.py +++ b/pydough/conversion/hybrid_translator.py @@ -7,7 +7,7 @@ from collections.abc import Iterable import pydough.pydough_operators as pydop -from pydough.configs import PyDoughConfigs +from pydough.configs import PyDoughSession from pydough.database_connectors import DatabaseDialect from pydough.errors import PyDoughSQLException from pydough.metadata import ( @@ -80,8 +80,8 @@ class HybridTranslator: Class used to translate PyDough QDAG nodes into the HybridTree structure. 
""" - def __init__(self, configs: PyDoughConfigs, dialect: DatabaseDialect): - self.configs = configs + def __init__(self, session: PyDoughSession): + self.session = session # An index used for creating fake column names for aliases self.alias_counter: int = 0 # A stack where each element is a hybrid tree being derived @@ -91,7 +91,7 @@ def __init__(self, configs: PyDoughConfigs, dialect: DatabaseDialect): # If True, rewrites MEDIAN calls into an average of the 1-2 median rows # or rewrites QUANTILE calls to select the first qualifying row, # both derived from window functions, otherwise leaves as-is. - self.rewrite_median_quantile: bool = dialect not in { + self.rewrite_median_quantile: bool = session.database.dialect not in { DatabaseDialect.ANSI, DatabaseDialect.SNOWFLAKE, } @@ -481,8 +481,8 @@ def postprocess_agg_output( # COUNT/NDISTINCT for left joins since the semantics of those functions # never allow returning NULL. if ( - (agg_call.operator == pydop.SUM and self.configs.sum_default_zero) - or (agg_call.operator == pydop.AVG and self.configs.avg_default_zero) + (agg_call.operator == pydop.SUM and self.session.config.sum_default_zero) + or (agg_call.operator == pydop.AVG and self.session.config.avg_default_zero) or ( agg_call.operator in (pydop.COUNT, pydop.NDISTINCT) and joins_can_nullify diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 70fd64208..e232b408f 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -10,8 +10,7 @@ from dataclasses import dataclass import pydough.pydough_operators as pydop -from pydough.configs import PyDoughConfigs -from pydough.database_connectors import DatabaseDialect +from pydough.configs import PyDoughSession from pydough.metadata import ( CartesianProductMetadata, GeneralJoinMetadata, @@ -1434,7 +1433,7 @@ def confirm_root(node: RelationalNode) -> RelationalRoot: def optimize_relational_tree( root: RelationalRoot, - configs: PyDoughConfigs, + session: PyDoughSession, additional_shuttles: list[RelationalExpressionShuttle], ) -> RelationalRoot: """ @@ -1443,7 +1442,7 @@ def optimize_relational_tree( Args: `root`: the relational root to optimize. - `configs`: the configuration settings to use during optimization. + `configs`: PyDough session used during optimization. `additional_shuttles`: additional relational expression shuttles to use for expression simplification. @@ -1472,7 +1471,7 @@ def optimize_relational_tree( root = confirm_root(pullup_projections(root)) # Push filters down as far as possible - root = confirm_root(push_filters(root, configs)) + root = confirm_root(push_filters(root, session)) # Merge adjacent projections, unless it would result in excessive duplicate # subexpression computations. @@ -1480,7 +1479,7 @@ def optimize_relational_tree( # Split aggregations on top of joins so part of the aggregate happens # underneath the join. - root = confirm_root(split_partial_aggregates(root, configs)) + root = confirm_root(split_partial_aggregates(root, session)) # Delete aggregations that are inferred to be redundant due to operating on # already unique data. @@ -1511,8 +1510,8 @@ def optimize_relational_tree( # pullup and pushdown and so on. 
for _ in range(2): root = confirm_root(pullup_projections(root)) - simplify_expressions(root, configs, additional_shuttles) - root = confirm_root(push_filters(root, configs)) + simplify_expressions(root, session, additional_shuttles) + root = confirm_root(push_filters(root, session)) root = pruner.prune_unused_columns(root) # Re-run projection merging, without pushing into joins. This will allow @@ -1534,8 +1533,7 @@ def optimize_relational_tree( def convert_ast_to_relational( node: PyDoughCollectionQDAG, columns: list[tuple[str, str]] | None, - configs: PyDoughConfigs, - dialect: DatabaseDialect = DatabaseDialect.ANSI, + session: PyDoughSession, ) -> RelationalRoot: """ Main API for converting from the collection QDAG form into relational @@ -1547,8 +1545,8 @@ def convert_ast_to_relational( describing every column that should be in the output, in the order they should appear, and the alias they should be given. If None, uses the most recent CALCULATE in the node to determine the columns. - `configs`: the configuration settings to use during translation. - `dialect`: the database dialect being used. + `session`: the PyDough session used to fetch configuration settings + and SQL dialect information. Returns: The RelationalRoot for the entire PyDough calculation that the @@ -1563,7 +1561,7 @@ def convert_ast_to_relational( # Convert the QDAG node to a hybrid tree, including any necessary # transformations such as de-correlation. - hybrid_translator: HybridTranslator = HybridTranslator(configs, dialect) + hybrid_translator: HybridTranslator = HybridTranslator(session) hybrid: HybridTree = hybrid_translator.convert_qdag_to_hybrid(node) # Then, invoke relational conversion procedure. The first element in the @@ -1579,7 +1577,7 @@ def convert_ast_to_relational( # Invoke the optimization procedures on the result to clean up the tree. additional_shuttles: list[RelationalExpressionShuttle] = [] optimized_result: RelationalRoot = optimize_relational_tree( - raw_result, configs, additional_shuttles + raw_result, session, additional_shuttles ) return optimized_result diff --git a/pydough/conversion/relational_simplification.py b/pydough/conversion/relational_simplification.py index 99a2e10d9..ec5eb3f8a 100644 --- a/pydough/conversion/relational_simplification.py +++ b/pydough/conversion/relational_simplification.py @@ -16,7 +16,7 @@ import pandas as pd import pydough.pydough_operators as pydop -from pydough.configs import PyDoughConfigs +from pydough.configs import PyDoughSession from pydough.relational import ( Aggregate, CallExpression, @@ -212,11 +212,11 @@ class SimplificationShuttle(RelationalExpressionShuttle): simplifying their inputs and placing their predicate sets on the stack. """ - def __init__(self, configs: PyDoughConfigs): + def __init__(self, session: PyDoughSession): self.stack: list[PredicateSet] = [] self._input_predicates: dict[RelationalExpression, PredicateSet] = {} self._no_group_aggregate: bool = False - self._configs: PyDoughConfigs = configs + self._session: PyDoughSession = session @property def input_predicates(self) -> dict[RelationalExpression, PredicateSet]: @@ -247,18 +247,18 @@ def no_group_aggregate(self, value: bool) -> None: self._no_group_aggregate = value @property - def configs(self) -> PyDoughConfigs: + def session(self) -> PyDoughSession: """ - Returns the PyDough configuration settings. + Returns the PyDough session used by the simplifier. 
""" - return self._configs + return self._session - @configs.setter - def configs(self, value: PyDoughConfigs) -> None: + @session.setter + def session(self, value: PyDoughSession) -> None: """ - Sets the PyDough configuration settings. + Sets the PyDough session used by the simplifier. """ - self._configs = value + self._session = value def reset(self) -> None: self.stack = [] @@ -659,9 +659,9 @@ def simplify_datetime_literal_part( # Derive the day of week as an integer, adjusting based on the # configured start of the week. dow: int = timestamp_value.weekday() - dow -= self.configs.start_of_week.pandas_dow + dow -= self.session.config.start_of_week.pandas_dow dow %= 7 - if not self.configs.start_week_as_zero: + if not self.session.config.start_week_as_zero: dow += 1 return LiteralExpression(dow, NumericType()) case _: @@ -740,7 +740,7 @@ def compress_datetime_literal_chain( # (accounting for the session configs) and subtract that # many days from the normalized timestamp. dow: int = timestamp_value.weekday() - dow -= self.configs.start_of_week.pandas_dow + dow -= self.session.config.start_of_week.pandas_dow dow %= 7 timestamp_value = timestamp_value.normalize() - pd.Timedelta( days=dow @@ -1423,11 +1423,11 @@ class SimplificationVisitor(RelationalVisitor): def __init__( self, - configs: PyDoughConfigs, + session: PyDoughSession, additional_shuttles: list[RelationalExpressionShuttle], ): self.stack: list[dict[RelationalExpression, PredicateSet]] = [] - self.shuttle: SimplificationShuttle = SimplificationShuttle(configs) + self.shuttle: SimplificationShuttle = SimplificationShuttle(session) self.additional_shuttles: list[RelationalExpressionShuttle] = ( additional_shuttles ) @@ -1660,7 +1660,7 @@ def visit_aggregate(self, node: Aggregate) -> None: def simplify_expressions( node: RelationalNode, - configs: PyDoughConfigs, + session: PyDoughSession, additional_shuttles: list[RelationalExpressionShuttle], ) -> None: """ @@ -1669,13 +1669,13 @@ def simplify_expressions( Args: `node`: The relational node to perform simplification on. - `configs`: The PyDough configuration settings. + `session`: The PyDough session used during the simplification. `additional_shuttles`: A list of additional shuttles to apply to the expressions of the node and its descendants. These shuttles are applied after the simplification shuttle, and can be used to perform additional transformations on the expressions. """ simplifier: SimplificationVisitor = SimplificationVisitor( - configs, additional_shuttles + session, additional_shuttles ) node.accept(simplifier) diff --git a/pydough/evaluation/evaluate_unqualified.py b/pydough/evaluation/evaluate_unqualified.py index d7693941b..2ab99bffb 100644 --- a/pydough/evaluation/evaluate_unqualified.py +++ b/pydough/evaluation/evaluate_unqualified.py @@ -9,7 +9,7 @@ import pandas as pd import pydough -from pydough.configs import PyDoughConfigs +from pydough.configs import PyDoughConfigs, PyDoughSession from pydough.conversion import convert_ast_to_relational from pydough.database_connectors import DatabaseContext from pydough.errors import ( @@ -27,12 +27,14 @@ __all__ = ["to_df", "to_sql"] -def _load_session_info( - **kwargs, -) -> tuple[GraphMetadata, PyDoughConfigs, DatabaseContext]: +def _load_session_info(**kwargs) -> PyDoughSession: """ Load the session information from the active session unless it is found - in the keyword arguments. + in the keyword arguments. The following variants are accepted: + - If `session` is found, it is used directly. 
+ - If `metadata`, `config` and/or `database` are found, they are used to + construct a new session. + - If none of these are found, the active session is used. Args: **kwargs: The keyword arguments to load the session information from. @@ -40,6 +42,32 @@ def _load_session_info( Returns: The metadata graph, configuration settings and Database context. """ + + # If there are no keyword arguments, return the active session. + if len(kwargs) == 0: + return pydough.active_session + + # If the session is provided, use it directly. Verify it has a metadata + # graph attached, and there are no other keyword arguments. + if "session" in kwargs: + session = kwargs.pop("session") + if not isinstance(session, PyDoughSession): + raise PyDoughSessionException( + f"Expected `session` to be a PyDoughSession, got {session.__class__.__name__}." + ) + if session.metadata is None: + raise PyDoughSessionException( + "Cannot evaluate Pydough without a metadata graph. " + "Please use `session.load_metadata_graph` to attach a graph to the session." + ) + if kwargs: + raise ValueError(f"Unexpected keyword arguments: {kwargs}") + return session + + # Otherwise, load the individual components and construct a session. + # If any of the components are missing, use the active session's value. The + # metadata graph is required, so if it is missing from both the keyword + # arguments and the active session, raise an error. metadata: GraphMetadata if "metadata" in kwargs: metadata = kwargs.pop("metadata") @@ -61,7 +89,13 @@ def _load_session_info( else: database = pydough.active_session.database assert not kwargs, f"Unexpected keyword arguments: {kwargs}" - return metadata, config, database + + # Construct the new session + new_session: PyDoughSession = PyDoughSession() + new_session._metadata = metadata + new_session._config = config + new_session._database = database + return new_session def _load_column_selection(kwargs: dict[str, object]) -> list[tuple[str, str]] | None: @@ -113,18 +147,15 @@ def to_sql(node: UnqualifiedNode, **kwargs) -> str: Returns: The SQL string corresponding to the unqualified query. """ - graph: GraphMetadata - config: PyDoughConfigs - database: DatabaseContext column_selection: list[tuple[str, str]] | None = _load_column_selection(kwargs) - graph, config, database = _load_session_info(**kwargs) - qualified: PyDoughQDAG = qualify_node(node, graph, config) + session: PyDoughSession = _load_session_info(**kwargs) + qualified: PyDoughQDAG = qualify_node(node, session) if not isinstance(qualified, PyDoughCollectionQDAG): raise pydough.active_session.error_builder.expected_collection(qualified) relational: RelationalRoot = convert_ast_to_relational( - qualified, column_selection, config, database.dialect + qualified, column_selection, session ) - return convert_relation_to_sql(relational, database.dialect, config) + return convert_relation_to_sql(relational, session) def to_df(node: UnqualifiedNode, **kwargs) -> pd.DataFrame: @@ -143,16 +174,13 @@ def to_df(node: UnqualifiedNode, **kwargs) -> pd.DataFrame: Returns: The DataFrame corresponding to the unqualified query. 
""" - graph: GraphMetadata - config: PyDoughConfigs - database: DatabaseContext column_selection: list[tuple[str, str]] | None = _load_column_selection(kwargs) display_sql: bool = bool(kwargs.pop("display_sql", False)) - graph, config, database = _load_session_info(**kwargs) - qualified: PyDoughQDAG = qualify_node(node, graph, config) + session: PyDoughSession = _load_session_info(**kwargs) + qualified: PyDoughQDAG = qualify_node(node, session) if not isinstance(qualified, PyDoughCollectionQDAG): raise pydough.active_session.error_builder.expected_collection(qualified) relational: RelationalRoot = convert_ast_to_relational( - qualified, column_selection, config, database.dialect + qualified, column_selection, session ) - return execute_df(relational, database, config, display_sql) + return execute_df(relational, session, display_sql) diff --git a/pydough/exploration/explain.py b/pydough/exploration/explain.py index 045bb1f02..2fc29c301 100644 --- a/pydough/exploration/explain.py +++ b/pydough/exploration/explain.py @@ -7,7 +7,7 @@ import pydough import pydough.pydough_operators as pydop -from pydough.configs import PyDoughConfigs +from pydough.configs import PyDoughSession from pydough.errors import PyDoughQDAGException from pydough.metadata.abstract_metadata import AbstractMetadata from pydough.metadata.collections import CollectionMetadata, SimpleTableMetadata @@ -249,7 +249,9 @@ def explain_graph(graph: GraphMetadata, verbose: bool) -> str: return "\n".join(lines) -def explain_unqualified(node: UnqualifiedNode, verbose: bool) -> str: +def explain_unqualified( + node: UnqualifiedNode, session: PyDoughSession, verbose: bool +) -> str: """ Displays information about an unqualified node, if it is possible to qualify the node as a collection. If not, then `explain_term` may need to @@ -262,6 +264,8 @@ def explain_unqualified(node: UnqualifiedNode, verbose: bool) -> str: Args: `node`: the unqualified node object being examined. + `session`: the session to use for the explanation. If not provided, + the active session will be used. `verbose`: if true, displays more detailed information about `node` and in a less compact format. @@ -270,13 +274,13 @@ def explain_unqualified(node: UnqualifiedNode, verbose: bool) -> str: """ lines: list[str] = [] qualified_node: PyDoughQDAG | None = None - config: PyDoughConfigs = pydough.active_session.config + session = pydough.active_session if session is None else session # Attempt to qualify the node, dumping an appropriate message if it could # not be qualified try: root: UnqualifiedRoot | None = find_unqualified_root(node) if root is not None: - qualified_node = qualify_node(node, root._parcel[0], config) + qualified_node = qualify_node(node, session) else: # If the root is None, it means that the node was an expression # without information about its context. @@ -497,7 +501,7 @@ def explain_unqualified(node: UnqualifiedNode, verbose: bool) -> str: def explain( data: AbstractMetadata | UnqualifiedNode, verbose: bool = False, - config: PyDoughConfigs | None = None, + session: PyDoughSession | None = None, ) -> str: """ Displays information about a PyDough metadata object or unqualified node. @@ -509,14 +513,14 @@ def explain( `data`: the metadata or unqualified node object being examined. `verbose`: if true, displays more detailed information about `data` and in a less compact format. - `config`: the configuration to use for the explanation. If not provided, - the active session's configuration will be used. 
+ `session`: the session to use for the explanation. If not provided, + the active session will be used. Returns: An explanation of `data`. """ - if config is None: - config = pydough.active_session.config + if session is None: + session = pydough.active_session match data: case GraphMetadata(): return explain_graph(data, verbose) @@ -525,7 +529,7 @@ def explain( case PropertyMetadata(): return explain_property(data, verbose) case UnqualifiedNode(): - return explain_unqualified(data, verbose) + return explain_unqualified(data, session, verbose) case _: raise NotImplementedError( f"Cannot call pydough.explain on argument of type {data.__class__.__name__}" diff --git a/pydough/exploration/structure.py b/pydough/exploration/structure.py index d43c94d0b..4fa9a8ac6 100644 --- a/pydough/exploration/structure.py +++ b/pydough/exploration/structure.py @@ -7,7 +7,7 @@ import pydough -from pydough.configs import PyDoughConfigs +from pydough.configs import PyDoughSession from pydough.metadata.collections import CollectionMetadata from pydough.metadata.graphs import GraphMetadata from pydough.metadata.properties import ( @@ -19,7 +19,7 @@ def explain_structure( - graph: GraphMetadata, config: PyDoughConfigs | None = None + graph: GraphMetadata, session: PyDoughSession | None = None ) -> str: """ Displays information about a PyDough metadata graph, including the @@ -35,14 +35,14 @@ def explain_structure( Args: `graph`: the metadata graph being examined. `config`: the configuration to use for the explanation. If not provided, - the active session's configuration will be used. + the active session will be used. Returns: The string representation of the graph's structure. """ assert isinstance(graph, GraphMetadata) - if config is None: - config = pydough.active_session.config + if session is None: + session = pydough.active_session lines: list[str] = [] lines.append(f"Structure of PyDough graph: {graph.name}") collection_names: list[str] = sorted(graph.get_collection_names()) diff --git a/pydough/exploration/term.py b/pydough/exploration/term.py index 4f6e62c20..6ab9734bb 100644 --- a/pydough/exploration/term.py +++ b/pydough/exploration/term.py @@ -9,7 +9,7 @@ import pydough import pydough.pydough_operators as pydop -from pydough.configs import PyDoughConfigs +from pydough.configs import PyDoughSession from pydough.errors import PyDoughQDAGException from pydough.qdag import ( BackReferenceExpression, @@ -97,7 +97,7 @@ def explain_term( node: UnqualifiedNode, term: UnqualifiedNode, verbose: bool = False, - config: PyDoughConfigs | None = None, + session: PyDoughSession | None = None, ) -> str: """ Displays information about an unqualified node as it exists within @@ -120,8 +120,8 @@ def explain_term( `node`. This term could be an expression or a collection. `verbose`: if true, displays more detailed information about `node` and `term` in a less compact format. - `config`: the configuration to use for the explanation. If not provided, - the active session's configuration will be used. + `config`: the PyDough session used for the explanation. If not provided, + the active session will be used. Returns: An explanation of `term` as it exists within the context of `node`. 
@@ -130,15 +130,15 @@ def explain_term( lines: list[str] = [] root: UnqualifiedRoot | None = find_unqualified_root(node) qualified_node: PyDoughQDAG | None = None - if config is None: - config = pydough.active_session.config + if session is None: + session = pydough.active_session try: if root is None: lines.append( f"Invalid first argument to pydough.explain_term: {display_raw(node)}" ) else: - qualified_node = qualify_node(node, root._parcel[0], config) + qualified_node = qualify_node(node, session) except PyDoughQDAGException as e: if "Unrecognized term" in str(e): lines.append( @@ -158,9 +158,7 @@ def explain_term( lines.append(f" {qualified_node.to_string()}") elif qualified_node is not None and root is not None: assert isinstance(qualified_node, PyDoughCollectionQDAG) - new_children, qualified_term = qualify_term( - qualified_node, term, root._parcel[0] - ) + new_children, qualified_term = qualify_term(qualified_node, term, session) if verbose: lines.append("Collection:") for line in qualified_node.to_tree_string().splitlines(): diff --git a/pydough/sqlglot/execute_relational.py b/pydough/sqlglot/execute_relational.py index e389b75e1..9ea45fb22 100644 --- a/pydough/sqlglot/execute_relational.py +++ b/pydough/sqlglot/execute_relational.py @@ -27,9 +27,8 @@ from sqlglot.optimizer.scope import traverse_scope, walk_in_scope import pydough -from pydough.configs import PyDoughConfigs +from pydough.configs import PyDoughSession from pydough.database_connectors import ( - DatabaseContext, DatabaseDialect, ) from pydough.logger import get_logger @@ -48,25 +47,24 @@ __all__ = ["convert_relation_to_sql", "execute_df"] -def convert_relation_to_sql( - relational: RelationalRoot, - dialect: DatabaseDialect, - config: PyDoughConfigs, -) -> str: +def convert_relation_to_sql(relational: RelationalRoot, session: PyDoughSession) -> str: """ Convert the given relational tree to a SQL string using the given dialect. Args: `relational`: The relational tree to convert. - `dialect`: The dialect to use for the conversion. + `session`: The PyDough session encapsulating the logic used to execute + the logic, including the PyDough configs and the database context. Returns: The SQL string representing the relational tree. """ glot_expr: SQLGlotExpression = SQLGlotRelationalVisitor( - dialect, config + session ).relational_to_sqlglot(relational) - sqlglot_dialect: SQLGlotDialect = convert_dialect_to_sqlglot(dialect) + sqlglot_dialect: SQLGlotDialect = convert_dialect_to_sqlglot( + session.database.dialect + ) # Apply the SQLGlot optimizer to the AST. try: @@ -411,8 +409,7 @@ def convert_dialect_to_sqlglot(dialect: DatabaseDialect) -> SQLGlotDialect: def execute_df( relational: RelationalRoot, - ctx: DatabaseContext, - config: PyDoughConfigs, + session: PyDoughSession, display_sql: bool = False, ) -> pd.DataFrame: """ @@ -421,15 +418,16 @@ def execute_df( Args: `relational`: The relational tree to execute. - `ctx`: The database context to execute the query in. + `session`: The PyDough session encapsulating the logic used to execute + the logic, including the database context. `display_sql`: if True, prints out the SQL that will be run before it is executed. 
Returns: The result of the query as a Pandas DataFrame """ - sql: str = convert_relation_to_sql(relational, ctx.dialect, config) + sql: str = convert_relation_to_sql(relational, session) if display_sql: pyd_logger = get_logger(__name__) pyd_logger.info(f"SQL query:\n {sql}") - return ctx.connection.execute_query_df(sql) + return session._database.connection.execute_query_df(sql) diff --git a/pydough/sqlglot/sqlglot_relational_expression_visitor.py b/pydough/sqlglot/sqlglot_relational_expression_visitor.py index 2aa9b85dd..dacb06642 100644 --- a/pydough/sqlglot/sqlglot_relational_expression_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_expression_visitor.py @@ -18,7 +18,6 @@ import pydough import pydough.pydough_operators as pydop -from pydough.configs import PyDoughConfigs from pydough.database_connectors import DatabaseDialect from pydough.errors import PyDoughSQLException from pydough.relational import ( @@ -47,20 +46,19 @@ class SQLGlotRelationalExpressionVisitor(RelationalExpressionVisitor): def __init__( self, - dialect: DatabaseDialect, - correlated_names: dict[str, str], - config: PyDoughConfigs, relational_visitor: "SQLGlotRelationalVisitor", + correlated_names: dict[str, str], ) -> None: # Keep a stack of SQLGlot expressions so we can build up # intermediate results. self._stack: list[SQLGlotExpression] = [] - self._dialect: DatabaseDialect = dialect + self._dialect: DatabaseDialect = relational_visitor._session.database.dialect self._correlated_names: dict[str, str] = correlated_names - self._config: PyDoughConfigs = config self._relational_visitor: SQLGlotRelationalVisitor = relational_visitor self._bindings: BaseTransformBindings = bindings_from_dialect( - dialect, config, self._relational_visitor + relational_visitor._session.database.dialect, + relational_visitor._session.config, + self._relational_visitor, ) def reset(self) -> None: diff --git a/pydough/sqlglot/sqlglot_relational_visitor.py b/pydough/sqlglot/sqlglot_relational_visitor.py index fe7977de7..7555021d8 100644 --- a/pydough/sqlglot/sqlglot_relational_visitor.py +++ b/pydough/sqlglot/sqlglot_relational_visitor.py @@ -15,7 +15,7 @@ from sqlglot.expressions import Star as SQLGlotStar from sqlglot.expressions import convert as sqlglot_convert -from pydough.configs import PyDoughConfigs +from pydough.configs import PyDoughSession from pydough.database_connectors import DatabaseDialect from pydough.relational import ( Aggregate, @@ -71,20 +71,14 @@ class SQLGlotRelationalVisitor(RelationalVisitor): the relational tree 1 node at a time. """ - def __init__( - self, - dialect: DatabaseDialect, - config: PyDoughConfigs, - ) -> None: + def __init__(self, session: PyDoughSession) -> None: # Keep a stack of SQLGlot expressions so we can build up # intermediate results. 
self._stack: list[Select] = [] - self._dialect: DatabaseDialect = dialect + self._session: PyDoughSession = session self._correlated_names: dict[str, str] = {} self._expr_visitor: SQLGlotRelationalExpressionVisitor = ( - SQLGlotRelationalExpressionVisitor( - dialect, self._correlated_names, config, self - ) + SQLGlotRelationalExpressionVisitor(self, self._correlated_names) ) self._alias_modifier: ColumnReferenceInputNameModifier = ( ColumnReferenceInputNameModifier() @@ -270,7 +264,7 @@ def _convert_ordering( ) # Ignore non-default na first/last positions for SQLite dialect na_first: bool - if self._dialect == DatabaseDialect.SQLITE: + if self._session.database.dialect == DatabaseDialect.SQLITE: if col.ascending: if not col.nulls_first: warnings.warn( diff --git a/pydough/unqualified/qualification.py b/pydough/unqualified/qualification.py index cbe669b25..e8e8b2f45 100644 --- a/pydough/unqualified/qualification.py +++ b/pydough/unqualified/qualification.py @@ -9,7 +9,7 @@ import pydough import pydough.pydough_operators as pydop -from pydough.configs import PyDoughConfigs +from pydough.configs import PyDoughSession from pydough.errors import PyDoughUnqualifiedException from pydough.metadata import GeneralJoinMetadata, GraphMetadata from pydough.pydough_operators.expression_operators import ( @@ -59,10 +59,10 @@ class Qualifier: - def __init__(self, graph: GraphMetadata, configs: PyDoughConfigs): - self._graph: GraphMetadata = graph - self._configs: PyDoughConfigs = configs - self._builder: AstNodeBuilder = AstNodeBuilder(graph) + def __init__(self, session: PyDoughSession): + self._session: PyDoughSession = session + assert session.metadata is not None + self._builder: AstNodeBuilder = AstNodeBuilder(session.metadata) @property def graph(self) -> GraphMetadata: @@ -70,7 +70,8 @@ def graph(self) -> GraphMetadata: The metadata for the PyDough graph in which is used to identify collections and properties. """ - return self._graph + assert self._session.metadata is not None + return self._session.metadata @property def builder(self) -> AstNodeBuilder: @@ -744,8 +745,8 @@ def _expressions_to_collations( Returns: The modified list of collation terms. """ - is_collation_propagated: bool = self._configs.propagate_collation - is_prev_asc: bool = self._configs.collation_default_asc + is_collation_propagated: bool = self._session.config.propagate_collation + is_prev_asc: bool = self._session.config.collation_default_asc modified_terms: list[UnqualifiedNode] = [] for idx, term in enumerate(terms): if isinstance(term, UnqualifiedCollation): @@ -1330,16 +1331,15 @@ def qualify_node( return answer -def qualify_node( - unqualified: UnqualifiedNode, graph: GraphMetadata, configs: PyDoughConfigs -) -> PyDoughQDAG: +def qualify_node(unqualified: UnqualifiedNode, session: PyDoughSession) -> PyDoughQDAG: """ Transforms an UnqualifiedNode into a qualified node. Args: `unqualified`: the UnqualifiedNode instance to be transformed. - `graph`: the metadata for the graph that the PyDough computations - are occurring within. + `session`: the session whose information should be used to derive + necessary information for the qualification, such as the graph and + configurations. Returns: The PyDough QDAG object for the qualified node. The result can be either @@ -1350,14 +1350,14 @@ def qualify_node( goes wrong during the qualification process, e.g. a term cannot be qualified or is not recognized. 
""" - qual: Qualifier = Qualifier(graph, configs) + qual: Qualifier = Qualifier(session) return qual.qualify_node( unqualified, qual.builder.build_global_context(), [], False ) def qualify_term( - collection: PyDoughCollectionQDAG, term: UnqualifiedNode, graph: GraphMetadata + collection: PyDoughCollectionQDAG, term: UnqualifiedNode, session: PyDoughSession ) -> tuple[list[PyDoughCollectionQDAG], PyDoughQDAG]: """ Transforms an UnqualifiedNode into a qualified node within the context of @@ -1369,8 +1369,9 @@ def qualify_term( context in which the term is being qualified. `term`: the UnqualifiedNode instance to be transformed into a qualified node within the context of `collection`. - `graph`: the metadata for the graph that the PyDough computations - are occurring within. + `session`: the session whose information should be used to derive + necessary information for the qualification, such as the graph and + configurations. Returns: A tuple where the second entry is the PyDough QDAG object for the @@ -1383,7 +1384,6 @@ def qualify_term( goes wrong during the qualification process, e.g. a term cannot be qualified or is not recognized. """ - configs: PyDoughConfigs = pydough.active_session.config - qual: Qualifier = Qualifier(graph, configs) + qual: Qualifier = Qualifier(session) children: list[PyDoughCollectionQDAG] = [] return children, qual.qualify_node(term, collection, children, True) diff --git a/tests/README.md b/tests/README.md index 78cea6eab..d21784b2d 100644 --- a/tests/README.md +++ b/tests/README.md @@ -62,7 +62,7 @@ The testing module uses `pytest` for running tests. The `conftest.py` file defin - `binary_operators`: Returns every PyDough expression operator for a BinOp. - `sqlite_dialects`: Returns the SQLite dialect. - `sqlite_people_jobs`: Returns a SQLite database connection with the PEOPLE and JOBS tables. -- `sqlite_people_jobs_context`: Returns a DatabaseContext for the SQLite PEOPLE and JOBS tables. +- `sqlite_people_jobs_session`: Returns a PyDough session containing a database for the SQLite PEOPLE and JOBS tables. - `sqlite_tpch_db_path`: Path to the TPCH database. - `sqlite_tpch_db`: Returns a connection to the SQLite TPCH database. By default it assumes tpch.db to be present in the root directory of PyDough. - `sqlite_tpch_db_context`: Returns a DatabaseContext for the SQLite TPCH database. diff --git a/tests/conftest.py b/tests/conftest.py index 8bc80c9bb..e858bf43a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ import pydough import pydough.pydough_operators as pydop -from pydough.configs import DayOfWeek, PyDoughConfigs +from pydough.configs import DayOfWeek, PyDoughConfigs, PyDoughSession from pydough.database_connectors import ( DatabaseConnection, DatabaseContext, @@ -111,6 +111,21 @@ def defog_config() -> PyDoughConfigs: return config +@pytest.fixture +def empty_sqlite_tpch_session( + sample_graph_path: str, default_config: PyDoughConfigs +) -> PyDoughSession: + """ + A PyDough session with an empty SQLite TPCH database connection and + the TPCH graph loaded. 
+ """ + session: PyDoughSession = PyDoughSession() + session.load_metadata_graph(sample_graph_path, "TPCH") + session.config = default_config + session.database = DatabaseContext(empty_connection, DatabaseDialect.SQLITE) + return session + + @pytest.fixture( params=[ pytest.param((sow, swaz), id=f"{sow.name.lower()}-{'zero' if swaz else 'one'}") @@ -449,14 +464,16 @@ def sqlite_people_jobs() -> DatabaseConnection: @pytest.fixture -def sqlite_people_jobs_context( +def sqlite_people_jobs_session( sqlite_people_jobs: DatabaseConnection, sqlite_dialects: DatabaseDialect -) -> DatabaseContext: +) -> PyDoughSession: """ Returns a DatabaseContext for the SQLite PEOPLE and JOBS tables with the given dialect. """ - return DatabaseContext(sqlite_people_jobs, sqlite_dialects) + session: PyDoughSession = PyDoughSession() + session.database = DatabaseContext(sqlite_people_jobs, sqlite_dialects) + return session @pytest.fixture(scope="module") @@ -483,13 +500,25 @@ def sqlite_tpch_db(sqlite_tpch_db_path: str) -> sqlite3.Connection: @pytest.fixture -def sqlite_tpch_db_context(sqlite_tpch_db_path: str, sqlite_tpch_db) -> DatabaseContext: +def sqlite_tpch_db_context(sqlite_tpch_db) -> DatabaseContext: """ Return a DatabaseContext for the SQLite TPCH database. """ return DatabaseContext(DatabaseConnection(sqlite_tpch_db), DatabaseDialect.SQLITE) +@pytest.fixture +def sqlite_tpch_session( + empty_sqlite_tpch_session: PyDoughSession, sqlite_tpch_db_context: DatabaseContext +) -> PyDoughSession: + """ + Returns a variant of the `empty_sqlite_tpch_session` fixture, but with the + database context set to the actual TPCH database connection. + """ + empty_sqlite_tpch_session.database = sqlite_tpch_db_context + return empty_sqlite_tpch_session + + @pytest.fixture(scope="session") def defog_graphs() -> graph_fetcher: """ diff --git a/tests/test_exploration.py b/tests/test_exploration.py index 8aed0add5..97e3f32ce 100644 --- a/tests/test_exploration.py +++ b/tests/test_exploration.py @@ -7,6 +7,7 @@ import pytest import pydough +from pydough.configs import PyDoughSession from pydough.metadata import GraphMetadata from pydough.unqualified import UnqualifiedNode from tests.test_pydough_functions.exploration_examples import ( @@ -1317,6 +1318,7 @@ def test_unqualified_node_exploration( ], verbose: bool, get_sample_graph: graph_fetcher, + empty_sqlite_tpch_session: PyDoughSession, ) -> None: """ Verifies that `pydough.explain` called on unqualified nodes produces the @@ -1327,7 +1329,9 @@ def test_unqualified_node_exploration( ) graph: GraphMetadata = get_sample_graph(graph_name) node: UnqualifiedNode = pydough.init_pydough_context(graph)(test_impl)() - answer: str = pydough.explain(node, verbose=verbose) + answer: str = pydough.explain( + node, verbose=verbose, session=empty_sqlite_tpch_session + ) expected_answer: str = verbose_answer if verbose else non_verbose_answer assert answer == expected_answer, ( "Mismatch between produced string and expected answer" @@ -1875,6 +1879,7 @@ def test_unqualified_term_exploration( ], verbose: bool, get_sample_graph: graph_fetcher, + empty_sqlite_tpch_session: PyDoughSession, ) -> None: """ Verifies that `pydough.explain` called on unqualified nodes produces the @@ -1885,7 +1890,9 @@ def test_unqualified_term_exploration( ) graph: GraphMetadata = get_sample_graph(graph_name) node, term = pydough.init_pydough_context(graph)(test_impl)() - answer: str = pydough.explain_term(node, term, verbose=verbose) + answer: str = pydough.explain_term( + node, term, verbose=verbose, 
session=empty_sqlite_tpch_session + ) expected_answer: str = verbose_answer if verbose else non_verbose_answer assert answer == expected_answer, ( "Mismatch between produced string and expected answer" diff --git a/tests/test_logging.py b/tests/test_logging.py index eee6db1af..ddc90dcef 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -4,8 +4,7 @@ import pytest -from pydough.configs import PyDoughConfigs -from pydough.database_connectors import DatabaseContext +from pydough.configs import PyDoughSession from pydough.logger import get_logger from pydough.sqlglot import execute_df from tests.test_pydough_functions.tpch_relational_plans import ( @@ -159,10 +158,7 @@ def test_get_logger_invalid_env_level(monkeypatch): get_logger(name="logger_invalid_env_level_test_logger") -def test_execute_df_logging( - sqlite_tpch_db_context: DatabaseContext, - default_config: PyDoughConfigs, -) -> None: +def test_execute_df_logging(sqlite_tpch_session: PyDoughSession) -> None: """ Test the example TPC-H relational trees executed on a SQLite database, and capture any SQL or output printed to stdout. @@ -172,12 +168,7 @@ def test_execute_df_logging( output_capture = io.StringIO() # Redirect stdout to the buffer with redirect_stdout(output_capture): - execute_df( - root, - sqlite_tpch_db_context, - default_config, - display_sql=True, - ) + execute_df(root, sqlite_tpch_session, display_sql=True) # Retrieve the output from the buffer captured_output = output_capture.getvalue() required_op = """ diff --git a/tests/test_metadata_errors.py b/tests/test_metadata_errors.py index 85793a47f..c53c4fe8a 100644 --- a/tests/test_metadata_errors.py +++ b/tests/test_metadata_errors.py @@ -8,7 +8,7 @@ import pytest from pydough import parse_json_metadata_from_file -from pydough.configs import PyDoughConfigs +from pydough.configs import PyDoughSession from pydough.errors import PyDoughMetadataException, PyDoughTypeException from pydough.metadata import CollectionMetadata, GraphMetadata from pydough.unqualified import UnqualifiedNode, qualify_node, transform_code @@ -866,7 +866,6 @@ def test_invalid_general_join_conditions( invalid_graph_path: str, pydough_string: str, error_message: str, - default_config: PyDoughConfigs, ) -> None: with pytest.raises(Exception, match=re.escape(error_message)): graph: GraphMetadata = parse_json_metadata_from_file( @@ -879,4 +878,6 @@ def test_invalid_general_join_conditions( exec(pydough_string, {}, local_variables) pydough_code = local_variables["answer"] assert isinstance(pydough_code, UnqualifiedNode) - qualify_node(pydough_code, graph, default_config) + session: PyDoughSession = PyDoughSession() + session.metadata = graph + qualify_node(pydough_code, session) diff --git a/tests/test_qdag_conversion.py b/tests/test_qdag_conversion.py index 9c87f3e35..733549606 100644 --- a/tests/test_qdag_conversion.py +++ b/tests/test_qdag_conversion.py @@ -7,7 +7,7 @@ import pytest -from pydough.configs import PyDoughConfigs +from pydough.configs import PyDoughSession from pydough.conversion.relational_converter import convert_ast_to_relational from pydough.qdag import AstNodeBuilder, PyDoughCollectionQDAG from pydough.types import ( @@ -2184,7 +2184,7 @@ def relational_test_data(request) -> tuple[CollectionTestInfo, str]: def test_ast_to_relational( relational_test_data: tuple[CollectionTestInfo, str], tpch_node_builder: AstNodeBuilder, - default_config: PyDoughConfigs, + empty_sqlite_tpch_session: PyDoughSession, get_plan_test_filename: Callable[[str], str], update_tests: bool, ) -> 
None: @@ -2195,7 +2195,7 @@ def test_ast_to_relational( calc_pipeline, file_name = relational_test_data file_path: str = get_plan_test_filename(file_name) collection: PyDoughCollectionQDAG = calc_pipeline.build(tpch_node_builder) - relational = convert_ast_to_relational(collection, None, default_config) + relational = convert_ast_to_relational(collection, None, empty_sqlite_tpch_session) if update_tests: with open(file_path, "w") as f: f.write(relational.to_tree_string() + "\n") @@ -2320,7 +2320,7 @@ def relational_alternative_config_test_data(request) -> tuple[CollectionTestInfo def test_ast_to_relational_alternative_aggregation_configs( relational_alternative_config_test_data: tuple[CollectionTestInfo, str], tpch_node_builder: AstNodeBuilder, - default_config: PyDoughConfigs, + empty_sqlite_tpch_session: PyDoughSession, get_plan_test_filename: Callable[[str], str], update_tests: bool, ) -> None: @@ -2332,10 +2332,10 @@ def test_ast_to_relational_alternative_aggregation_configs( """ calc_pipeline, file_name = relational_alternative_config_test_data file_path: str = get_plan_test_filename(file_name) - default_config.sum_default_zero = False - default_config.avg_default_zero = True + empty_sqlite_tpch_session.config.sum_default_zero = False + empty_sqlite_tpch_session.config.avg_default_zero = True collection: PyDoughCollectionQDAG = calc_pipeline.build(tpch_node_builder) - relational = convert_ast_to_relational(collection, None, default_config) + relational = convert_ast_to_relational(collection, None, empty_sqlite_tpch_session) if update_tests: with open(file_path, "w") as f: f.write(relational.to_tree_string() + "\n") diff --git a/tests/test_qualification.py b/tests/test_qualification.py index 50280736b..4b2f1d790 100644 --- a/tests/test_qualification.py +++ b/tests/test_qualification.py @@ -8,8 +8,7 @@ import pytest from pydough import init_pydough_context -from pydough.configs import PyDoughConfigs -from pydough.metadata import GraphMetadata +from pydough.configs import PyDoughConfigs, PyDoughSession from pydough.qdag import PyDoughCollectionQDAG, PyDoughQDAG from pydough.unqualified import ( UnqualifiedNode, @@ -63,9 +62,6 @@ impl_tpch_q21, impl_tpch_q22, ) -from tests.testing_utilities import ( - graph_fetcher, -) @pytest.mark.parametrize( @@ -947,16 +943,17 @@ def test_qualify_node_to_ast_string( impl: Callable[..., UnqualifiedNode], answer_tree_str: str, - get_sample_graph: graph_fetcher, - default_config: PyDoughConfigs, + empty_sqlite_tpch_session: PyDoughSession, ) -> None: """ Tests that a PyDough unqualified node can be correctly translated to its qualified DAG version, with the correct string representation. 
""" - graph: GraphMetadata = get_sample_graph("TPCH") - unqualified: UnqualifiedNode = init_pydough_context(graph)(impl)() - qualified: PyDoughQDAG = qualify_node(unqualified, graph, default_config) + assert empty_sqlite_tpch_session.metadata is not None + unqualified: UnqualifiedNode = init_pydough_context( + empty_sqlite_tpch_session.metadata + )(impl)() + qualified: PyDoughQDAG = qualify_node(unqualified, empty_sqlite_tpch_session) assert isinstance(qualified, PyDoughCollectionQDAG), ( "Expected qualified answer to be a collection, not an expression" ) @@ -1051,7 +1048,7 @@ def test_qualify_node_collation( answer_tree_str: str, collation_default_asc: bool, propagate_collation: bool, - get_sample_graph: graph_fetcher, + empty_sqlite_tpch_session: PyDoughSession, ) -> None: """ Tests that a PyDough unqualified node can be correctly translated to its @@ -1060,9 +1057,12 @@ def test_qualify_node_collation( custom_config: PyDoughConfigs = PyDoughConfigs() custom_config.collation_default_asc = collation_default_asc custom_config.propagate_collation = propagate_collation - graph: GraphMetadata = get_sample_graph("TPCH") - unqualified: UnqualifiedNode = init_pydough_context(graph)(impl)() - qualified: PyDoughQDAG = qualify_node(unqualified, graph, custom_config) + assert empty_sqlite_tpch_session.metadata is not None + empty_sqlite_tpch_session.config = custom_config + unqualified: UnqualifiedNode = init_pydough_context( + empty_sqlite_tpch_session.metadata + )(impl)() + qualified: PyDoughQDAG = qualify_node(unqualified, empty_sqlite_tpch_session) assert isinstance(qualified, PyDoughCollectionQDAG), ( "Expected qualified answer to be a collection, not an expression" ) diff --git a/tests/test_qualification_errors.py b/tests/test_qualification_errors.py index a20478c01..0be9a3aaf 100644 --- a/tests/test_qualification_errors.py +++ b/tests/test_qualification_errors.py @@ -8,15 +8,11 @@ import pytest import pydough -from pydough.configs import PyDoughConfigs -from pydough.metadata import GraphMetadata +from pydough.configs import PyDoughSession from pydough.unqualified import ( UnqualifiedNode, qualify_node, ) -from tests.testing_utilities import ( - graph_fetcher, -) @pytest.mark.parametrize( @@ -209,7 +205,7 @@ def test_qualify_error( pydough_text: str, error_msg: str, - get_sample_graph: graph_fetcher, + empty_sqlite_tpch_session: PyDoughSession, ) -> None: """ Tests that the qualification process correctly raises the expected error @@ -219,10 +215,10 @@ def test_qualify_error( multiple lines, but must end with storing the answers in a variable called `result`. 
""" - graph: GraphMetadata = get_sample_graph("TPCH") - default_config: PyDoughConfigs = pydough.active_session.config with pytest.raises(Exception, match=re.escape(error_msg)): unqualified: UnqualifiedNode = pydough.from_string( - pydough_text, answer_variable="result", metadata=graph + pydough_text, + answer_variable="result", + metadata=empty_sqlite_tpch_session.metadata, ) - qualify_node(unqualified, graph, default_config) + qualify_node(unqualified, empty_sqlite_tpch_session) diff --git a/tests/test_relational_execution.py b/tests/test_relational_execution.py index 27d727fda..0d2999850 100644 --- a/tests/test_relational_execution.py +++ b/tests/test_relational_execution.py @@ -7,8 +7,7 @@ import pandas as pd import pytest -from pydough.configs import PyDoughConfigs -from pydough.database_connectors import DatabaseContext +from pydough.configs import PyDoughSession from pydough.pydough_operators import ( EQU, SUM, @@ -32,10 +31,7 @@ pytestmark = [pytest.mark.execute] -def test_person_total_salary( - sqlite_people_jobs_context: DatabaseContext, - default_config: PyDoughConfigs, -) -> None: +def test_person_total_salary(sqlite_people_jobs_session: PyDoughSession) -> None: """ Tests a simple join and aggregate to compute the total salary for each person in the PEOPLE table. @@ -91,7 +87,7 @@ def test_person_total_salary( join_type=JoinType.LEFT, ), ) - output: list[Any] = execute_df(result, sqlite_people_jobs_context, default_config) + output: list[Any] = execute_df(result, sqlite_people_jobs_session) people_results: list[str] = [f"Person {i}" for i in range(10)] salary_results: list[float] = [ sum((i + j + 5.7) * 1000 for j in range(2)) for i in range(10) diff --git a/tests/test_relational_execution_tpch.py b/tests/test_relational_execution_tpch.py index 44c0352e5..a55157fde 100644 --- a/tests/test_relational_execution_tpch.py +++ b/tests/test_relational_execution_tpch.py @@ -5,8 +5,7 @@ import pandas as pd import pytest -from pydough.configs import PyDoughConfigs -from pydough.database_connectors import DatabaseContext +from pydough.configs import PyDoughConfigs, PyDoughSession from pydough.relational import RelationalRoot from pydough.sqlglot import execute_df from tests.test_pydough_functions.tpch_outputs import ( @@ -46,12 +45,12 @@ def test_tpch( root: RelationalRoot, output: pd.DataFrame, - sqlite_tpch_db_context: DatabaseContext, + sqlite_tpch_session: PyDoughSession, default_config: PyDoughConfigs, ) -> None: """ Test the example TPC-H relational trees executed on a SQLite database. 
""" - result = execute_df(root, sqlite_tpch_db_context, default_config) + result = execute_df(root, sqlite_tpch_session) pd.testing.assert_frame_equal(result, output) diff --git a/tests/test_relational_nodes_to_sqlglot.py b/tests/test_relational_nodes_to_sqlglot.py index 285db4f50..69f4062f3 100644 --- a/tests/test_relational_nodes_to_sqlglot.py +++ b/tests/test_relational_nodes_to_sqlglot.py @@ -32,8 +32,6 @@ ) from sqlglot.expressions import Identifier as Ident -from pydough.configs import PyDoughConfigs -from pydough.database_connectors import DatabaseDialect from pydough.pydough_operators import ( ABS, ADD, @@ -73,10 +71,9 @@ ) -@pytest.fixture(scope="module") -def sqlglot_relational_visitor() -> SQLGlotRelationalVisitor: - config: PyDoughConfigs = PyDoughConfigs() - return SQLGlotRelationalVisitor(DatabaseDialect.SQLITE, config) +@pytest.fixture +def sqlglot_relational_visitor(empty_sqlite_tpch_session) -> SQLGlotRelationalVisitor: + return SQLGlotRelationalVisitor(empty_sqlite_tpch_session) @dataclass diff --git a/tests/test_relational_to_sql.py b/tests/test_relational_to_sql.py index 2c144973b..801e9601f 100644 --- a/tests/test_relational_to_sql.py +++ b/tests/test_relational_to_sql.py @@ -8,7 +8,7 @@ import pytest -from pydough.configs import PyDoughConfigs +from pydough.configs import PyDoughSession from pydough.database_connectors import DatabaseDialect from pydough.pydough_operators import ( ABS, @@ -578,15 +578,13 @@ def test_convert_relation_to_sqlite_sql( test_name: str, get_sql_test_filename: Callable[[str, DatabaseDialect], str], update_tests: bool, - default_config: PyDoughConfigs, + empty_sqlite_tpch_session: PyDoughSession, ) -> None: """ Test converting a relational tree to SQL text in the SQLite dialect. """ file_path: str = get_sql_test_filename(test_name, DatabaseDialect.SQLITE) - created_sql: str = convert_relation_to_sql( - root, DatabaseDialect.SQLITE, default_config - ) + created_sql: str = convert_relation_to_sql(root, empty_sqlite_tpch_session) if update_tests: with open(file_path, "w") as f: f.write(created_sql + "\n") @@ -959,16 +957,14 @@ def test_function_to_sql( test_name: str, get_sql_test_filename: Callable[[str, DatabaseDialect], str], update_tests: bool, - default_config: PyDoughConfigs, + empty_sqlite_tpch_session: PyDoughSession, ) -> None: """ Tests that should be small as we need to just test converting a function to SQL. """ file_path: str = get_sql_test_filename(f"func_{test_name}", DatabaseDialect.ANSI) - created_sql: str = convert_relation_to_sql( - root, DatabaseDialect.SQLITE, default_config - ) + created_sql: str = convert_relation_to_sql(root, empty_sqlite_tpch_session) if update_tests: with open(file_path, "w") as f: f.write(created_sql + "\n") diff --git a/tests/testing_utilities.py b/tests/testing_utilities.py index 634502131..f96f9d91a 100644 --- a/tests/testing_utilities.py +++ b/tests/testing_utilities.py @@ -40,7 +40,7 @@ import pydough import pydough.pydough_operators as pydop from pydough import init_pydough_context, to_df, to_sql -from pydough.configs import PyDoughConfigs +from pydough.configs import PyDoughConfigs, PyDoughSession from pydough.conversion import convert_ast_to_relational from pydough.database_connectors import DatabaseContext from pydough.errors import PyDoughTestingException @@ -1172,14 +1172,15 @@ def run_relational_test( # Run the PyDough code through the pipeline up until it is converted to # a relational plan. 
- if config is None: - config = pydough.active_session.config - qualified: PyDoughQDAG = qualify_node(root, graph, config) + session: PyDoughSession = PyDoughSession() + session.metadata = graph + session.config = config if config is not None else pydough.active_session.config + qualified: PyDoughQDAG = qualify_node(root, session) assert isinstance(qualified, PyDoughCollectionQDAG), ( "Expected qualified answer to be a collection, not an expression" ) relational: RelationalRoot = convert_ast_to_relational( - qualified, _load_column_selection({"columns": self.columns}), config + qualified, _load_column_selection({"columns": self.columns}), session ) # Either update the reference solution, or compare the generated From 4706344f3d8050204bbc8f7554cc774a4310024e Mon Sep 17 00:00:00 2001 From: knassre-bodo Date: Mon, 15 Sep 2025 13:32:48 -0400 Subject: [PATCH 2/3] Updating docs [RUN CI] --- documentation/usage.md | 101 +++++++++++++++-------------------------- 1 file changed, 37 insertions(+), 64 deletions(-) diff --git a/documentation/usage.md b/documentation/usage.md index 4387fb79b..b3bef5a7a 100644 --- a/documentation/usage.md +++ b/documentation/usage.md @@ -436,45 +436,23 @@ result = european_countries.CALCULATE(name, n_custs=COUNT(customers)) pydough.to_sql(result, columns=["name", "n_custs"]) ``` -""" -s = ''' -european_countries = nations.WHERE(region.name == "EUROPE") -result = european_countries.CALCULATE(name, n_custs=COUNT(customers)) -''' -pydough.active_session.load_metadata_graph("tests/test_metadata/sample_graphs.json", "TPCH") -u = pydough.from_string(s) -print(pydough.to_sql(result, columns=["name", "n_custs"])) -""" - ```sql -SELECT name, COALESCE(agg_0, 0) AS n_custs -FROM ( - SELECT name, agg_0 - FROM ( - SELECT name, key - FROM ( - SELECT _table_alias_0.name AS name, _table_alias_0.key AS key, _table_alias_1.name AS name_3 - FROM ( - SELECT n_name AS name, n_nationkey AS key, n_regionkey AS region_key FROM main.NATION - ) AS _table_alias_0 - LEFT JOIN ( - SELECT r_name AS name, r_regionkey AS key - FROM main.REGION - ) AS _table_alias_1 - ON region_key = _table_alias_1.key - ) - WHERE name_3 = 'EUROPE' - ) - LEFT JOIN ( - SELECT nation_key, COUNT(*) AS agg_0 - FROM ( - SELECT c_nationkey AS nation_key - FROM main.CUSTOMER - ) - GROUP BY nation_key - ) - ON key = nation_key +WITH _s3 AS ( + SELECT + c_nationkey, + COUNT(*) AS n_rows + FROM tpch.customer + GROUP BY + 1 ) +SELECT + nation.n_name AS name, + _s3.n_rows AS n_custs +FROM tpch.nation AS nation +JOIN tpch.region AS region + ON nation.n_regionkey = region.r_regionkey AND region.r_name = 'EUROPE' +JOIN _s3 AS _s3 + ON _s3.c_nationkey = nation.n_nationkey ``` See the [demo notebooks](../demos/README.md) for more instances of how to use the `to_sql` API. 
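Below is a minimal sketch of how a caller might use the new `session` keyword with `to_sql`, reusing `result` from the example above. It assumes `PyDoughSession` exposes the same `load_metadata_graph` helper as `pydough.active_session`; the graph path and graph name are illustrative, not prescriptive.

```py
import pydough
from pydough.configs import PyDoughConfigs, PyDoughSession

# Build a standalone session instead of mutating pydough.active_session.
# load_metadata_graph is assumed to behave as it does on the active session;
# the path and graph name here are illustrative.
session = PyDoughSession()
session.load_metadata_graph("tests/test_metadata/sample_graphs.json", "TPCH")

# Give this session its own configuration settings.
custom_config = PyDoughConfigs()
custom_config.collation_default_asc = False
session.config = custom_config

# Pass the session explicitly; it stands in for the separate metadata /
# config / database keyword arguments and cannot be combined with them.
print(pydough.to_sql(result, columns=["name", "n_custs"], session=session))
```

Building the session once keeps the metadata, configuration, and database context together, which mirrors the pattern the tests in this patch adopt via their session fixtures.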
@@ -628,41 +606,35 @@ The value of `sql` is the following SQL query text as a Python string: ```sql WITH _s7 AS ( SELECT - ROUND( - COALESCE( - SUM( - lineitem.l_extendedprice * ( - 1 - lineitem.l_discount - ) * ( - 1 - lineitem.l_tax - ) - lineitem.l_quantity * partsupp.ps_supplycost - ), - 0 - ), - 2 - ) AS revenue_year, - partsupp.ps_suppkey + partsupp.ps_suppkey, + SUM( + lineitem.l_extendedprice * ( + 1 - lineitem.l_discount + ) * ( + 1 - lineitem.l_tax + ) - lineitem.l_quantity * partsupp.ps_supplycost + ) AS sum_rev FROM main.partsupp AS partsupp JOIN main.part AS part ON part.p_name LIKE 'coral%' AND part.p_partkey = partsupp.ps_partkey JOIN main.lineitem AS lineitem - ON CAST(STRFTIME('%Y', lineitem.l_shipdate) AS INTEGER) = 1996 + ON EXTRACT(YEAR FROM CAST(lineitem.l_shipdate AS DATETIME)) = 1996 AND lineitem.l_partkey = partsupp.ps_partkey AND lineitem.l_shipmode = 'TRUCK' AND lineitem.l_suppkey = partsupp.ps_suppkey GROUP BY - partsupp.ps_suppkey + 1 ) SELECT supplier.s_name AS name, - _s7.revenue_year + ROUND(COALESCE(_s7.sum_rev, 0), 2) AS revenue_year FROM main.supplier AS supplier JOIN main.nation AS nation ON nation.n_name = 'JAPAN' AND nation.n_nationkey = supplier.s_nationkey JOIN _s7 AS _s7 ON _s7.ps_suppkey = supplier.s_suppkey ORDER BY - revenue_year DESC + 2 DESC LIMIT 5 ``` @@ -700,27 +672,27 @@ The value of `sql` is the following SQL query text as a Python string: ```sql WITH _s1 AS ( SELECT - COALESCE(SUM(o_totalprice), 0) AS total, + o_custkey, COUNT(*) AS n_rows, - o_custkey + SUM(o_totalprice) AS sum_o_totalprice FROM main.orders WHERE - o_orderdate < '1997-01-01' - AND o_orderdate >= '1996-01-01' + o_orderdate < CAST('1997-01-01' AS DATE) + AND o_orderdate >= CAST('1996-01-01' AS DATE) AND o_orderpriority = '1-URGENT' AND o_totalprice > 100000 GROUP BY - o_custkey + 1 ) SELECT customer.c_name AS name, _s1.n_rows AS n_orders, - _s1.total + _s1.sum_o_totalprice AS total FROM main.customer AS customer JOIN _s1 AS _s1 ON _s1.o_custkey = customer.c_custkey ORDER BY - total DESC + 3 DESC ``` @@ -788,7 +760,7 @@ The `explain` API is a more generic explanation interface that can be called on - A specific property within a specific collection within a metadata graph object (can be accessed as `graph["collection_name"]["property_name"]`) - The PyDough code for a collection that could have `to_sql` or `to_df` called on it. -The `explain` API also has an optional `verbose` argument (default=False) that enables displaying additional information. +The `explain` API also has an optional `verbose` argument (default=False) that enables displaying additional information. It also has an optional `session` argument to specify what configs etc. to use when explaining certain terms (if not provided, uses `pydough.active_session`). Below are examples of each of these behaviors, using a knowledge graph for the TPCH schema. @@ -1006,7 +978,8 @@ The `explain` API is limited in that it can only be called on complete PyDough c To handle cases where you need to learn about a term within a collection, you can use the `explain_term` API. The first argument to `explain_term` is PyDough code for a collection, which can have `explain` called on it, and the second is PyDough code for a term that can be evaluated within the context of that collection (e.g. a scalar term of the collection, or one of its sub-collections). -The `explain_term` API also has a `verbose` keyword argument (default False) to specify whether to include a more detailed explanation, as opposed to a more compact summary. 
+The `explain_term` API also has a `verbose` keyword argument (default False) to specify whether to include a more detailed explanation, as opposed to a more compact summary. It also has an optional `session` argument to specify which PyDough session (and therefore which metadata and configuration settings) to use when explaining the term; if not provided, `pydough.active_session` is used.
+
 Below are examples of using `explain_term`, using a knowledge graph for the TPCH schema. For each of these examples, `european_countries` is the "context" collection, which could have `to_sql` or `to_df` called on it, and `term` is the term being explained with regards to `european_countries`.

From f3200bd3efac3ca1e7d258bba35240acd199bda4 Mon Sep 17 00:00:00 2001
From: knassre-bodo
Date: Mon, 15 Sep 2025 13:57:37 -0400
Subject: [PATCH 3/3] Updating docs [RUN CI]

---
 tests/test_relational_execution_tpch.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/test_relational_execution_tpch.py b/tests/test_relational_execution_tpch.py
index a55157fde..156c1400e 100644
--- a/tests/test_relational_execution_tpch.py
+++ b/tests/test_relational_execution_tpch.py
@@ -5,7 +5,7 @@
 import pandas as pd
 import pytest
 
-from pydough.configs import PyDoughConfigs, PyDoughSession
+from pydough.configs import PyDoughSession
 from pydough.relational import RelationalRoot
 from pydough.sqlglot import execute_df
 from tests.test_pydough_functions.tpch_outputs import (
@@ -46,7 +46,6 @@ def test_tpch(
     root: RelationalRoot,
     output: pd.DataFrame,
     sqlite_tpch_session: PyDoughSession,
-    default_config: PyDoughConfigs,
 ) -> None:
     """
     Test the example TPC-H relational trees executed on a