
Commit f1a9961

Add a session keyword argument to to_sql and to_df (#427)
Resolves #421.
1 parent e204a4a commit f1a9961

27 files changed: +311 -291 lines

documentation/usage.md

Lines changed: 42 additions & 54 deletions
@@ -453,6 +453,7 @@ The `to_sql` API takes in PyDough code and transforms it into SQL query text wit
 - `metadata`: the PyDough knowledge graph to use for the conversion (if omitted, `pydough.active_session.metadata` is used instead).
 - `config`: the PyDough configuration settings to use for the conversion (if omitted, `pydough.active_session.config` is used instead).
 - `database`: the database context to use for the conversion (if omitted, `pydough.active_session.database` is used instead). The database context matters because it controls which SQL dialect is used for the translation.
+- `session`: a PyDough session object which, if provided, is used instead of `pydough.active_session` or the `metadata` / `config` / `database` arguments. Note: this argument cannot be used alongside those arguments.
 
 Below is an example of using `pydough.to_sql` and the output (the SQL output may be outdated if PyDough's SQL conversion process has been updated):
 
@@ -464,34 +465,22 @@ pydough.to_sql(result, columns=["name", "n_custs"])
 ```
 
 ```sql
-SELECT name, COALESCE(agg_0, 0) AS n_custs
-FROM (
-  SELECT name, agg_0
-  FROM (
-    SELECT name, key
-    FROM (
-      SELECT _table_alias_0.name AS name, _table_alias_0.key AS key, _table_alias_1.name AS name_3
-      FROM (
-        SELECT n_name AS name, n_nationkey AS key, n_regionkey AS region_key FROM main.NATION
-      ) AS _table_alias_0
-      LEFT JOIN (
-        SELECT r_name AS name, r_regionkey AS key
-        FROM main.REGION
-      ) AS _table_alias_1
-      ON region_key = _table_alias_1.key
-    )
-    WHERE name_3 = 'EUROPE'
-  )
-  LEFT JOIN (
-    SELECT nation_key, COUNT(*) AS agg_0
-    FROM (
-      SELECT c_nationkey AS nation_key
-      FROM main.CUSTOMER
-    )
-    GROUP BY nation_key
-  )
-  ON key = nation_key
+WITH _s3 AS (
+  SELECT
+    c_nationkey,
+    COUNT(*) AS n_rows
+  FROM tpch.customer
+  GROUP BY
+    1
 )
+SELECT
+  nation.n_name AS name,
+  _s3.n_rows AS n_custs
+FROM tpch.nation AS nation
+JOIN tpch.region AS region
+  ON nation.n_regionkey = region.r_regionkey AND region.r_name = 'EUROPE'
+JOIN _s3 AS _s3
+  ON _s3.c_nationkey = nation.n_nationkey
 ```
 
 See the [demo notebooks](../demos/README.md) for more instances of how to use the `to_sql` API.
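
For a quick illustration of the new `session` argument documented above, here is a minimal hedged sketch; `result` is the PyDough collection from the example, and `my_session` stands in for any `PyDoughSession` object:

```python
import pydough

# Either override individual pieces (anything omitted falls back to
# pydough.active_session)...
sql_text = pydough.to_sql(result, columns=["name", "n_custs"])

# ...or, new in this commit, pass an entire session that supplies the
# metadata, config, and database context all at once.
my_session = pydough.active_session  # placeholder for any PyDoughSession
sql_text = pydough.to_sql(result, columns=["name", "n_custs"], session=my_session)

# Per the note above, `session` cannot be combined with the `metadata`,
# `config`, or `database` arguments.
```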
@@ -506,6 +495,7 @@ The `to_df` API does all the same steps as the [`to_sql` API](#pydoughto_sql), b
 - `metadata`: the PyDough knowledge graph to use for the conversion (if omitted, `pydough.active_session.metadata` is used instead).
 - `config`: the PyDough configuration settings to use for the conversion (if omitted, `pydough.active_session.config` is used instead).
 - `database`: the database context to use for the conversion (if omitted, `pydough.active_session.database` is used instead). The database context matters because it controls which SQL dialect is used for the translation.
+- `session`: a PyDough session object which, if provided, is used instead of `pydough.active_session` or the `metadata` / `config` / `database` arguments. Note: this argument cannot be used alongside those arguments.
 - `display_sql`: displays the sql before executing in a logger.
 
 Below is an example of using `pydough.to_df` and the output, attached to a sqlite database containing data for the TPC-H schema:
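
(Separately from the notebook example that sentence refers to, a minimal hedged sketch of the new `session` argument on `to_df`; `result` and `my_session` are the hypothetical placeholders carried over from the sketch above.)

```python
# Same pattern as to_sql: one session object instead of separate
# metadata / config / database arguments.
df = pydough.to_df(result, session=my_session)

# display_sql=True logs the generated SQL before executing it.
df = pydough.to_df(result, session=my_session, display_sql=True)
```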
@@ -644,41 +634,35 @@ The value of `sql` is the following SQL query text as a Python string:
 ```sql
 WITH _s7 AS (
   SELECT
-    ROUND(
-      COALESCE(
-        SUM(
-          lineitem.l_extendedprice * (
-            1 - lineitem.l_discount
-          ) * (
-            1 - lineitem.l_tax
-          ) - lineitem.l_quantity * partsupp.ps_supplycost
-        ),
-        0
-      ),
-      2
-    ) AS revenue_year,
-    partsupp.ps_suppkey
+    partsupp.ps_suppkey,
+    SUM(
+      lineitem.l_extendedprice * (
+        1 - lineitem.l_discount
+      ) * (
+        1 - lineitem.l_tax
+      ) - lineitem.l_quantity * partsupp.ps_supplycost
+    ) AS sum_rev
   FROM main.partsupp AS partsupp
   JOIN main.part AS part
     ON part.p_name LIKE 'coral%' AND part.p_partkey = partsupp.ps_partkey
   JOIN main.lineitem AS lineitem
-    ON CAST(STRFTIME('%Y', lineitem.l_shipdate) AS INTEGER) = 1996
+    ON EXTRACT(YEAR FROM CAST(lineitem.l_shipdate AS DATETIME)) = 1996
     AND lineitem.l_partkey = partsupp.ps_partkey
     AND lineitem.l_shipmode = 'TRUCK'
     AND lineitem.l_suppkey = partsupp.ps_suppkey
   GROUP BY
-    partsupp.ps_suppkey
+    1
 )
 SELECT
   supplier.s_name AS name,
-  _s7.revenue_year
+  ROUND(COALESCE(_s7.sum_rev, 0), 2) AS revenue_year
 FROM main.supplier AS supplier
 JOIN main.nation AS nation
   ON nation.n_name = 'JAPAN' AND nation.n_nationkey = supplier.s_nationkey
 JOIN _s7 AS _s7
   ON _s7.ps_suppkey = supplier.s_suppkey
 ORDER BY
-  revenue_year DESC
+  2 DESC
 LIMIT 5
 ```
 
@@ -716,27 +700,27 @@ The value of `sql` is the following SQL query text as a Python string:
 ```sql
 WITH _s1 AS (
   SELECT
-    COALESCE(SUM(o_totalprice), 0) AS total,
+    o_custkey,
     COUNT(*) AS n_rows,
-    o_custkey
+    SUM(o_totalprice) AS sum_o_totalprice
   FROM main.orders
   WHERE
-    o_orderdate < '1997-01-01'
-    AND o_orderdate >= '1996-01-01'
+    o_orderdate < CAST('1997-01-01' AS DATE)
+    AND o_orderdate >= CAST('1996-01-01' AS DATE)
     AND o_orderpriority = '1-URGENT'
     AND o_totalprice > 100000
   GROUP BY
-    o_custkey
+    1
 )
 SELECT
   customer.c_name AS name,
   _s1.n_rows AS n_orders,
-  _s1.total
+  _s1.sum_o_totalprice AS total
 FROM main.customer AS customer
 JOIN _s1 AS _s1
   ON _s1.o_custkey = customer.c_custkey
 ORDER BY
-  total DESC
+  3 DESC
 ```
 
 <!-- TOC --><a name="exploration-apis"></a>
@@ -804,7 +788,9 @@ The `explain` API is a more generic explanation interface that can be called on
 - A specific property within a specific collection within a metadata graph object (can be accessed as `graph["collection_name"]["property_name"]`)
 - The PyDough code for a collection that could have `to_sql` or `to_df` called on it.
 
-The `explain` API also has an optional `verbose` argument (default=False) that enables displaying additional information.
+The `explain` API has the following optional arguments:
+* `verbose` (default False): specifies whether to include a more detailed explanation, as opposed to a more compact summary.
+* `session` (default None): if provided, specifies what configs etc. to use when explaining PyDough code objects (if not provided, uses `pydough.active_session`).
 
 Below are examples of each of these behaviors, using a knowledge graph for the TPCH schema.
 
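As a rough, hedged sketch of those arguments (assuming `explain` returns printable text; `graph`, the collection name `nations`, `result`, and `my_session` are hypothetical placeholders):

```python
# Explaining metadata objects: the whole graph, one collection, one property.
print(pydough.explain(graph))
print(pydough.explain(graph["nations"], verbose=True))
print(pydough.explain(graph["nations"]["name"]))

# Explaining PyDough code that could have to_sql / to_df called on it, using
# the configs of a specific session instead of pydough.active_session.
print(pydough.explain(result, session=my_session))
```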
@@ -1022,7 +1008,9 @@ The `explain` API is limited in that it can only be called on complete PyDough c
 
 To handle cases where you need to learn about a term within a collection, you can use the `explain_term` API. The first argument to `explain_term` is PyDough code for a collection, which can have `explain` called on it, and the second is PyDough code for a term that can be evaluated within the context of that collection (e.g. a scalar term of the collection, or one of its sub-collections).
 
-The `explain_term` API also has a `verbose` keyword argument (default False) to specify whether to include a more detailed explanation, as opposed to a more compact summary.
+The `explain_term` API has the following optional arguments:
+* `verbose` (default False): specifies whether to include a more detailed explanation, as opposed to a more compact summary.
+* `session` (default None): if provided, specifies what configs etc. to use when explaining certain terms (if not provided, uses `pydough.active_session`).
 
 Below are examples of using `explain_term`, using a knowledge graph for the TPCH schema. For each of these examples, `european_countries` is the "context" collection, which could have `to_sql` or `to_df` called on it, and `term` is the term being explained with regards to `european_countries`.
 
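A similarly hedged sketch of `explain_term`, reusing the `european_countries` and `term` names described above (both defined elsewhere, and `my_session` remains a placeholder):

```python
# Explain what `term` means when evaluated within `european_countries`.
print(pydough.explain_term(european_countries, term))

# The same call with a fuller explanation and an explicit session.
print(pydough.explain_term(european_countries, term, verbose=True, session=my_session))
```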
pydough/conversion/agg_split.py

Lines changed: 13 additions & 11 deletions
@@ -7,7 +7,7 @@
 
 
 import pydough.pydough_operators as pydop
-from pydough.configs import PyDoughConfigs
+from pydough.configs import PyDoughSession
 from pydough.relational import (
     Aggregate,
     CallExpression,
@@ -51,15 +51,15 @@
 """
 
 
-def decompose_aggregations(node: Aggregate, config: PyDoughConfigs) -> RelationalNode:
+def decompose_aggregations(node: Aggregate, session: PyDoughSession) -> RelationalNode:
     """
     Splits up an aggregate node into an aggregate followed by a projection when
     the aggregate contains 1+ calls to functions that can be split into 1+
     calls to partial aggregates, e.g. how AVG(X) = SUM(X)/COUNT(X).
 
     Args:
         `node`: the aggregate node to be decomposed.
-        `config`: the current configuration settings.
+        `session`: the PyDough session used during the transformation.
 
     Returns:
         The projection node on top of the new aggregate, overall containing the
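
The docstring describes the rewrite in terms of the identity AVG(X) = SUM(X) / COUNT(X); the following is a minimal plain-Python sketch of that identity (not PyDough's relational nodes), including the `avg_default_zero` fallback mentioned further down:

```python
def decomposed_avg(values: list[float], avg_default_zero: bool = False) -> float | None:
    """AVG computed the way the decomposition does: partial SUM and COUNT
    aggregates combined by a projection, with an optional default of 0 for
    empty groups (mirroring the avg_default_zero config)."""
    total = sum(values)  # partial aggregate: SUM(X)
    count = len(values)  # partial aggregate: COUNT(X)
    if count == 0:
        return 0.0 if avg_default_zero else None  # DEFAULT_TO-style fallback
    return total / count  # projection on top of the partial aggregates


assert decomposed_avg([2.0, 4.0]) == 3.0
assert decomposed_avg([], avg_default_zero=True) == 0.0
```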
@@ -110,7 +110,7 @@ def decompose_aggregations(node: Aggregate, config: PyDoughConfigs) -> Relationa
         )
         # If the config specifies that the default value for AVG should be
         # zero, wrap the division in a DEFAULT_TO call.
-        if config.avg_default_zero:
+        if session.config.avg_default_zero:
             avg_call = CallExpression(
                 pydop.DEFAULT_TO,
                 agg.data_type,
@@ -277,7 +277,7 @@ def transpose_aggregate_join(
 
 
 def attempt_join_aggregate_transpose(
-    node: Aggregate, join: Join, config: PyDoughConfigs
+    node: Aggregate, join: Join, session: PyDoughSession
 ) -> tuple[RelationalNode, bool]:
     """
     Determine whether the aggregate join transpose operation can occur, and if
@@ -392,7 +392,7 @@ def attempt_join_aggregate_transpose(
     for col in node.aggregations.values():
         if col.op in decomposable_aggfuncs:
             return split_partial_aggregates(
-                decompose_aggregations(node, config), config
+                decompose_aggregations(node, session), session
             ), False
 
     # Keep a dictionary for the projection columns that will be used to post-process
@@ -477,15 +477,15 @@ def attempt_join_aggregate_transpose(
     # top of the top aggregate.
     if need_projection:
         new_node: RelationalNode = node.copy(
-            inputs=[split_partial_aggregates(input, config) for input in node.inputs]
+            inputs=[split_partial_aggregates(input, session) for input in node.inputs]
         )
         return Project(new_node, projection_columns), False
     else:
         return node, True
 
 
 def split_partial_aggregates(
-    node: RelationalNode, config: PyDoughConfigs
+    node: RelationalNode, session: PyDoughSession
 ) -> RelationalNode:
     """
     Splits partial aggregates above joins into two aggregates, one above the
@@ -494,19 +494,21 @@ def split_partial_aggregates(
 
     Args:
         `node`: the root node of the relational plan to be transformed.
-        `config`: the current configuration settings.
+        `session`: the PyDough session used during the transformation.
 
     Returns:
         The transformed node. The transformation is also done-in-place.
     """
     # If the aggregate+join pattern is detected, attempt to do the transpose.
     handle_inputs: bool = True
     if isinstance(node, Aggregate) and isinstance(node.input, Join):
-        node, handle_inputs = attempt_join_aggregate_transpose(node, node.input, config)
+        node, handle_inputs = attempt_join_aggregate_transpose(
+            node, node.input, session
+        )
 
     # If needed, recursively invoke the procedure on all inputs to the node.
     if handle_inputs:
         node = node.copy(
-            inputs=[split_partial_aggregates(input, config) for input in node.inputs]
+            inputs=[split_partial_aggregates(input, session) for input in node.inputs]
         )
     return node

pydough/conversion/filter_pushdown.py

Lines changed: 5 additions & 5 deletions
@@ -6,7 +6,7 @@
 
 
 import pydough.pydough_operators as pydop
-from pydough.configs import PyDoughConfigs
+from pydough.configs import PyDoughSession
 from pydough.relational import (
     Aggregate,
     CallExpression,
@@ -66,7 +66,7 @@ class FilterPushdownShuttle(RelationalShuttle):
     cannot be pushed further.
     """
 
-    def __init__(self, configs: PyDoughConfigs):
+    def __init__(self, session: PyDoughSession):
         # The set of filters that are currently being pushed down. When
         # visit_xxx is called, it is presumed that the set of conditions in
         # self.filters are the conditions that can be pushed down as far as the
@@ -76,7 +76,7 @@ def __init__(self, configs: PyDoughConfigs):
         # simplification logic to aid in advanced filter predicate inference,
         # such as determining that a left join is redundant because if the RHS
         # column is null then the filter will always be false.
-        self.simplifier: SimplificationShuttle = SimplificationShuttle(configs)
+        self.simplifier: SimplificationShuttle = SimplificationShuttle(session)
 
     def reset(self):
         self.filters = set()
@@ -307,7 +307,7 @@ def visit_empty_singleton(self, empty_singleton: EmptySingleton) -> RelationalNo
         return self.flush_remaining_filters(empty_singleton, self.filters, set())
 
 
-def push_filters(node: RelationalNode, configs: PyDoughConfigs) -> RelationalNode:
+def push_filters(node: RelationalNode, session: PyDoughSession) -> RelationalNode:
     """
     Transpose filter conditions down as far as possible.
 
@@ -321,5 +321,5 @@ def push_filters(node: RelationalNode, configs: PyDoughConfigs) -> RelationalNod
     the node or into one of its inputs, or possibly both if there are
     multiple filters.
     """
-    pusher: FilterPushdownShuttle = FilterPushdownShuttle(configs)
+    pusher: FilterPushdownShuttle = FilterPushdownShuttle(session)
     return node.accept_shuttle(pusher)
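
A hedged call-site sketch of what this signature change means for callers; the import path is assumed from the file location, and `relational_root` / `session` are hypothetical:

```python
from pydough.conversion.filter_pushdown import push_filters

# Before this commit, callers passed just the configs:
#   pushed = push_filters(relational_root, configs)
# Now the whole session is threaded through instead:
pushed = push_filters(relational_root, session)
```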

pydough/conversion/hybrid_translator.py

Lines changed: 6 additions & 6 deletions
@@ -7,7 +7,7 @@
 from collections.abc import Iterable
 
 import pydough.pydough_operators as pydop
-from pydough.configs import PyDoughConfigs
+from pydough.configs import PyDoughSession
 from pydough.database_connectors import DatabaseDialect
 from pydough.errors import PyDoughSQLException
 from pydough.metadata import (
@@ -81,8 +81,8 @@ class HybridTranslator:
     Class used to translate PyDough QDAG nodes into the HybridTree structure.
     """
 
-    def __init__(self, configs: PyDoughConfigs, dialect: DatabaseDialect):
-        self.configs = configs
+    def __init__(self, session: PyDoughSession):
+        self.session = session
         # An index used for creating fake column names for aliases
         self.alias_counter: int = 0
         # A stack where each element is a hybrid tree being derived
@@ -92,7 +92,7 @@ def __init__(self, configs: PyDoughConfigs, dialect: DatabaseDialect):
         # If True, rewrites MEDIAN calls into an average of the 1-2 median rows
         # or rewrites QUANTILE calls to select the first qualifying row,
         # both derived from window functions, otherwise leaves as-is.
-        self.rewrite_median_quantile: bool = dialect not in {
+        self.rewrite_median_quantile: bool = session.database.dialect not in {
             DatabaseDialect.ANSI,
             DatabaseDialect.SNOWFLAKE,
             DatabaseDialect.POSTGRES,
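
A hedged sketch of constructing the translator after this change; the import path is assumed from the file location and `session` is a hypothetical `PyDoughSession`:

```python
from pydough.conversion.hybrid_translator import HybridTranslator

# Before: HybridTranslator(configs, dialect) took configs and dialect separately.
# Now the translator takes the session and reads both from it; for example the
# SQL dialect comes from session.database.dialect.
translator = HybridTranslator(session)
```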
@@ -484,8 +484,8 @@ def postprocess_agg_output(
         # COUNT/NDISTINCT for left joins since the semantics of those functions
         # never allow returning NULL.
         if (
-            (agg_call.operator == pydop.SUM and self.configs.sum_default_zero)
-            or (agg_call.operator == pydop.AVG and self.configs.avg_default_zero)
+            (agg_call.operator == pydop.SUM and self.session.config.sum_default_zero)
+            or (agg_call.operator == pydop.AVG and self.session.config.avg_default_zero)
             or (
                 agg_call.operator in (pydop.COUNT, pydop.NDISTINCT)
                 and joins_can_nullify

0 commit comments
