Skip to content

Commit 0b84cbf

Browse files
Add GETPART function to PyDough (#381)
Co-authored-by: knassre-bodo <[email protected]>
1 parent 92e453f commit 0b84cbf

26 files changed

+3066
-14
lines changed

documentation/functions.md

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ Below is the list of every function/operator currently supported in PyDough as a
2727
* [STRIP](#strip)
2828
* [REPLACE](#replace)
2929
* [STRCOUNT](#strcount)
30+
* [GETPART](#getpart)
3031
- [Datetime Functions](#datetime-functions)
3132
* [DATETIME](#datetime)
3233
* [YEAR](#year)
@@ -438,6 +439,46 @@ Customers.CALCULATE(count_substring= STRCOUNT(name, "")) # returns 0 by default
438439
| `'Alex Rodriguez'`| `STRCOUNT('Alex Rodriguez', 'e')`| `2` |
439440
| `'Hello World'`| `STRCOUNT('Hello World', 'll')` | `1` |
440441

442+
443+
<!-- TOC --><a name="getpart"></a>
444+
445+
### GETPART
446+
447+
The `GETPART` function extracts the N-th part from a string, splitting it by a specified delimiter.
448+
449+
- The first argument is the input string to split.
450+
- The second argument is the delimiter string.
451+
- The third argument is the index of the part to extract. This index can be positive (counting from the start, 0-based) or negative (counting from the end, -1 is the last part).
452+
453+
If the index is out of range, `GETPART` returns `None`. If the delimiter is an empty string, the function will not split the input string and the first part will be the entire string.
454+
455+
```py
456+
# Extracts the first name from a full name
457+
Customers.CALCULATE(first_name = GETPART(name, " ", 1))
458+
459+
# Extracts the last name from a full name
460+
Customers.CALCULATE(last_name = GETPART(name, " ", -1))
461+
462+
# Extracts the second part from a hyphen-separated string
463+
Parts.CALCULATE(second_code = GETPART(code, "-", 2))
464+
```
465+
466+
| **Input String** | **Delimiter** | **Index** | **GETPART Result** |
467+
|---------------------- |-------------- |-----------|--------------------|
468+
| `"Alex Rodriguez"` | `" "` | `1` | `"Alex"` |
469+
| `"Alex Rodriguez"` | `" "` | `0` | `"Alex"` |
470+
| `"Alex Rodriguez"` | `" "` | `2` | `"Rodriguez"` |
471+
| `"Alex Rodriguez"` | `" "` | `-1` | `"Rodriguez"` |
472+
| `"Alex Rodriguez"` | `""` | `1` | `"Alex Rodriguez"` |
473+
| `"a-b-c-d"` | `"-"` | `3` | `"c"` |
474+
| `"a-b-c-d"` | `"-"` | `-2` | `"c"` |
475+
| `"a-b-c-d"` | `"-"` | `5` | `None` |
476+
| `"a-b-c-d"` | `"-"` | `-5` | `None` |
477+
478+
> [!NOTE]
479+
> - Indexing is one-based from the start and negative indices count from the end.
480+
> - The 0 index will be treated as 1, returning the first part.
481+
441482
<!-- TOC --><a name="datetime-functions"></a>
442483

443484
## Datetime Functions

pydough/pydough_operators/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"FLOAT",
3636
"FLOOR",
3737
"GEQ",
38+
"GETPART",
3839
"GRT",
3940
"HAS",
4041
"HASNOT",
@@ -139,6 +140,7 @@
139140
FLOAT,
140141
FLOOR,
141142
GEQ,
143+
GETPART,
142144
GRT,
143145
HAS,
144146
HASNOT,

pydough/pydough_operators/expression_operators/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ These functions must be called on singular data as a function.
8484
- `FIND`: returns the index(0-indexed) of the first occurrence of the second argument within the first argument, or -1 if the second argument is not found.
8585
- `STRIP`: returns the first argument with all leading and trailing whitespace removed, including newlines, tabs, and spaces. If the second argument is provided, it is used as the set of characters to remove from the leading and trailing ends of the first argument.
8686
- `REPLACE`: returns the first argument with all instances of the second argument replaced by the third argument. If the third argument is not provided, all instances of the second argument are removed from the first argument.
87-
8887
- `STRCOUNT`: returns how many times the second argument appears in the first argument. If one or both arguments are an empty string the return would be 0
88+
- `GETPART`: extracts the N-th part from a string, splitting it by a specified delimiter. The first argument is the input string, the second argument is the delimiter string, and the third argument is the index of the part to extract (can be positive for counting from the start, or negative for counting from the end; 1-based indexing). If the index is out of range, returns a `None` value. If the delimiter is an empty string the string will not be splitted, the first part is the entire string.
8989

9090
##### Datetime Functions
9191

pydough/pydough_operators/expression_operators/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"FLOAT",
3333
"FLOOR",
3434
"GEQ",
35+
"GETPART",
3536
"GRT",
3637
"HAS",
3738
"HASNOT",
@@ -131,6 +132,7 @@
131132
FLOAT,
132133
FLOOR,
133134
GEQ,
135+
GETPART,
134136
GRT,
135137
HAS,
136138
HASNOT,

pydough/pydough_operators/expression_operators/registered_expression_operators.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"FLOAT",
2828
"FLOOR",
2929
"GEQ",
30+
"GETPART",
3031
"GRT",
3132
"HAS",
3233
"HASNOT",
@@ -329,3 +330,6 @@
329330
STRING = ExpressionFunctionOperator(
330331
"STRING", False, RequireArgRange(1, 2), ConstantType(StringType())
331332
)
333+
GETPART = ExpressionFunctionOperator(
334+
"GETPART", False, RequireNumArgs(3), ConstantType(StringType())
335+
)

pydough/sqlglot/execute_relational.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sqlglot.dialects import Dialect as SQLGlotDialect
1010
from sqlglot.dialects import SQLite as SQLiteDialect
1111
from sqlglot.errors import SqlglotError
12-
from sqlglot.expressions import Alias, Column, Select, Table
12+
from sqlglot.expressions import Alias, Column, Select, Table, With
1313
from sqlglot.expressions import Expression as SQLGlotExpression
1414
from sqlglot.optimizer import find_all_in_scope
1515
from sqlglot.optimizer.annotate_types import annotate_types
@@ -19,7 +19,6 @@
1919
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
2020
from sqlglot.optimizer.normalize import normalize
2121
from sqlglot.optimizer.optimize_joins import optimize_joins
22-
from sqlglot.optimizer.pushdown_projections import pushdown_projections
2322
from sqlglot.optimizer.qualify import qualify
2423
from sqlglot.optimizer.simplify import simplify
2524
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
@@ -37,6 +36,7 @@
3736

3837
from .override_merge_subqueries import merge_subqueries
3938
from .override_pushdown_predicates import pushdown_predicates
39+
from .override_pushdown_projections import pushdown_projections
4040
from .sqlglot_relational_visitor import SQLGlotRelationalVisitor
4141

4242
__all__ = ["convert_relation_to_sql", "execute_df"]
@@ -98,7 +98,11 @@ def apply_sqlglot_optimizer(
9898

9999
# Rewrite sqlglot AST to have normalized and qualified tables and columns.
100100
glot_expr = qualify(
101-
glot_expr, dialect=dialect, quote_identifiers=False, isolate_tables=True
101+
glot_expr,
102+
dialect=dialect,
103+
quote_identifiers=False,
104+
isolate_tables=True,
105+
validate_qualify_columns=False,
102106
)
103107

104108
# Rewrite sqlglot AST to remove unused columns projections.
@@ -111,14 +115,16 @@ def apply_sqlglot_optimizer(
111115
# Convert scalar subqueries into cross joins.
112116
# Convert correlated or vectorized subqueries into a group by so it is not
113117
# a many to many left join.
114-
glot_expr = unnest_subqueries(glot_expr)
118+
# PyDough skips this step if there are any recursive CTEs in the query, due
119+
# to flaws in how SQLGlot handles such subqueries.
120+
if not any(e.args.get("recursive") for e in glot_expr.find_all(With)):
121+
glot_expr = unnest_subqueries(glot_expr)
115122

116123
# limit clauses, which is not correct.
117124
# Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS.
118125
glot_expr = pushdown_predicates(glot_expr, dialect=dialect)
119126

120127
# Removes cross joins if possible and reorder joins based on predicate
121-
# dependencies.
122128
glot_expr = optimize_joins(glot_expr)
123129

124130
# Rewrite derived tables as CTES, deduplicating if possible.
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""
2+
Overridden version of the pushdown_projections.py file from sqlglot.
3+
"""
4+
5+
from collections import defaultdict
6+
7+
from sqlglot import exp
8+
from sqlglot.optimizer.pushdown_projections import SELECT_ALL, _remove_unused_selections
9+
from sqlglot.optimizer.scope import Scope, traverse_scope
10+
from sqlglot.schema import ensure_schema
11+
12+
# ruff: noqa
13+
# mypy: ignore-errors
14+
# ruff & mypy should not try to typecheck or verify any of this
15+
16+
17+
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
18+
"""
19+
Rewrite sqlglot AST to remove unused columns projections.
20+
21+
Example:
22+
>>> import sqlglot
23+
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
24+
>>> expression = sqlglot.parse_one(sql)
25+
>>> pushdown_projections(expression).sql()
26+
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
27+
28+
Args:
29+
expression (sqlglot.Expression): expression to optimize
30+
remove_unused_selections (bool): remove selects that are unused
31+
Returns:
32+
sqlglot.Expression: optimized expression
33+
"""
34+
# Map of Scope to all columns being selected by outer queries.
35+
schema = ensure_schema(schema)
36+
source_column_alias_count = {}
37+
referenced_columns = defaultdict(set)
38+
39+
# We build the scope tree (which is traversed in DFS postorder), then iterate
40+
# over the result in reverse order. This should ensure that the set of selected
41+
# columns for a particular scope are completely build by the time we get to it.
42+
for scope in reversed(traverse_scope(expression)):
43+
parent_selections = referenced_columns.get(scope, {SELECT_ALL})
44+
alias_count = source_column_alias_count.get(scope, 0)
45+
46+
# We can't remove columns SELECT DISTINCT nor UNION DISTINCT.
47+
# PyDough Change: also include ANY set op
48+
if scope.expression.args.get("distinct") or isinstance(
49+
scope.expression, exp.SetOperation
50+
):
51+
parent_selections = {SELECT_ALL}
52+
53+
if isinstance(scope.expression, exp.SetOperation):
54+
left, right = scope.union_scopes
55+
referenced_columns[left] = parent_selections
56+
57+
if any(select.is_star for select in right.expression.selects):
58+
referenced_columns[right] = parent_selections
59+
elif not any(select.is_star for select in left.expression.selects):
60+
if scope.expression.args.get("by_name"):
61+
referenced_columns[right] = referenced_columns[left]
62+
else:
63+
referenced_columns[right] = [
64+
right.expression.selects[i].alias_or_name
65+
for i, select in enumerate(left.expression.selects)
66+
if SELECT_ALL in parent_selections
67+
or select.alias_or_name in parent_selections
68+
]
69+
70+
if isinstance(scope.expression, exp.Select):
71+
if remove_unused_selections:
72+
_remove_unused_selections(scope, parent_selections, schema, alias_count)
73+
74+
if scope.expression.is_star:
75+
continue
76+
77+
# Group columns by source name
78+
selects = defaultdict(set)
79+
for col in scope.columns:
80+
table_name = col.table
81+
col_name = col.name
82+
selects[table_name].add(col_name)
83+
84+
# Push the selected columns down to the next scope
85+
for name, (node, source) in scope.selected_sources.items():
86+
if isinstance(source, Scope):
87+
columns = (
88+
{SELECT_ALL} if scope.pivots else selects.get(name) or set()
89+
)
90+
referenced_columns[source].update(columns)
91+
92+
column_aliases = node.alias_column_names
93+
if column_aliases:
94+
source_column_alias_count[source] = len(column_aliases)
95+
96+
return expression

pydough/sqlglot/sqlglot_relational_expression_visitor.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33
the relation Tree to a single SQLGlot query component.
44
"""
55

6+
__all__ = ["SQLGlotRelationalExpressionVisitor"]
7+
68
import datetime
79
import warnings
10+
from typing import TYPE_CHECKING
811

912
import sqlglot.expressions as sqlglot_expressions
1013
from sqlglot.expressions import Expression as SQLGlotExpression
@@ -28,7 +31,8 @@
2831
from .sqlglot_helpers import set_glot_alias
2932
from .transform_bindings import BaseTransformBindings, bindings_from_dialect
3033

31-
__all__ = ["SQLGlotRelationalExpressionVisitor"]
34+
if TYPE_CHECKING:
35+
from .sqlglot_relational_visitor import SQLGlotRelationalVisitor
3236

3337

3438
class SQLGlotRelationalExpressionVisitor(RelationalExpressionVisitor):
@@ -42,14 +46,18 @@ def __init__(
4246
dialect: DatabaseDialect,
4347
correlated_names: dict[str, str],
4448
config: PyDoughConfigs,
49+
relational_visitor: "SQLGlotRelationalVisitor",
4550
) -> None:
4651
# Keep a stack of SQLGlot expressions so we can build up
4752
# intermediate results.
4853
self._stack: list[SQLGlotExpression] = []
4954
self._dialect: DatabaseDialect = dialect
5055
self._correlated_names: dict[str, str] = correlated_names
5156
self._config: PyDoughConfigs = config
52-
self._bindings: BaseTransformBindings = bindings_from_dialect(dialect, config)
57+
self._relational_visitor: SQLGlotRelationalVisitor = relational_visitor
58+
self._bindings: BaseTransformBindings = bindings_from_dialect(
59+
dialect, config, self._relational_visitor
60+
)
5361

5462
def reset(self) -> None:
5563
"""

pydough/sqlglot/sqlglot_relational_visitor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,9 @@ def __init__(
7979
self._dialect: DatabaseDialect = dialect
8080
self._correlated_names: dict[str, str] = {}
8181
self._expr_visitor: SQLGlotRelationalExpressionVisitor = (
82-
SQLGlotRelationalExpressionVisitor(dialect, self._correlated_names, config)
82+
SQLGlotRelationalExpressionVisitor(
83+
dialect, self._correlated_names, config, self
84+
)
8385
)
8486
self._alias_modifier: ColumnReferenceInputNameModifier = (
8587
ColumnReferenceInputNameModifier()

pydough/sqlglot/transform_bindings/__init__.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,22 @@
55

66
__all__ = ["BaseTransformBindings", "SQLiteTransformBindings", "bindings_from_dialect"]
77

8+
from typing import TYPE_CHECKING
9+
810
from pydough.configs import PyDoughConfigs
911
from pydough.database_connectors import DatabaseDialect
1012

1113
from .base_transform_bindings import BaseTransformBindings
1214
from .sqlite_transform_bindings import SQLiteTransformBindings
1315

16+
if TYPE_CHECKING:
17+
from pydough.sqlglot.sqlglot_relational_visitor import SQLGlotRelationalVisitor
18+
1419

1520
def bindings_from_dialect(
16-
dialect: DatabaseDialect, configs: PyDoughConfigs
21+
dialect: DatabaseDialect,
22+
configs: PyDoughConfigs,
23+
visitor: "SQLGlotRelationalVisitor",
1724
) -> BaseTransformBindings:
1825
"""
1926
Returns a binding instance corresponding to a specific database
@@ -29,8 +36,8 @@ def bindings_from_dialect(
2936
"""
3037
match dialect:
3138
case DatabaseDialect.ANSI:
32-
return BaseTransformBindings(configs)
39+
return BaseTransformBindings(configs, visitor)
3340
case DatabaseDialect.SQLITE:
34-
return SQLiteTransformBindings(configs)
41+
return SQLiteTransformBindings(configs, visitor)
3542
case _:
3643
raise NotImplementedError(f"Unsupported dialect: {dialect}")

0 commit comments

Comments
 (0)