Skip to content

Commit b0c37ae

Browse files
authored
feat: add support for lateral joins in the builder (#72)
Introduce methods for creating LATERAL, LEFT LATERAL, and CROSS LATERAL joins in the SQL builder, along with comprehensive tests to validate their functionality.
1 parent 81bd325 commit b0c37ae

File tree

5 files changed

+629
-202
lines changed

5 files changed

+629
-202
lines changed

sqlspec/_sql.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,42 @@ def cross_join_(self) -> "JoinBuilder":
628628
"""Create a CROSS JOIN builder."""
629629
return JoinBuilder("cross join")
630630

631+
@property
632+
def lateral_join_(self) -> "JoinBuilder":
633+
"""Create a LATERAL JOIN builder.
634+
635+
Returns:
636+
JoinBuilder configured for LATERAL JOIN
637+
638+
Example:
639+
```python
640+
query = (
641+
sql.select("u.name", "arr.value")
642+
.from_("users u")
643+
.join(sql.lateral_join_("UNNEST(u.tags)").on("true"))
644+
)
645+
```
646+
"""
647+
return JoinBuilder("lateral join", lateral=True)
648+
649+
@property
650+
def left_lateral_join_(self) -> "JoinBuilder":
651+
"""Create a LEFT LATERAL JOIN builder.
652+
653+
Returns:
654+
JoinBuilder configured for LEFT LATERAL JOIN
655+
"""
656+
return JoinBuilder("left join", lateral=True)
657+
658+
@property
659+
def cross_lateral_join_(self) -> "JoinBuilder":
660+
"""Create a CROSS LATERAL JOIN builder.
661+
662+
Returns:
663+
JoinBuilder configured for CROSS LATERAL JOIN
664+
"""
665+
return JoinBuilder("cross join", lateral=True)
666+
631667
def __getattr__(self, name: str) -> "Column":
632668
"""Dynamically create column references.
633669

sqlspec/builder/mixins/_join_operations.py

Lines changed: 205 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from sqlspec.utils.type_guards import has_query_builder_parameters
1515

1616
if TYPE_CHECKING:
17-
from sqlspec.builder._column import ColumnExpression
1817
from sqlspec.core.statement import SQL
1918
from sqlspec.protocols import SQLBuilderProtocol
2019

@@ -36,74 +35,133 @@ def join(
3635
on: Optional[Union[str, exp.Expression, "SQL"]] = None,
3736
alias: Optional[str] = None,
3837
join_type: str = "INNER",
38+
lateral: bool = False,
3939
) -> Self:
4040
builder = cast("SQLBuilderProtocol", self)
41+
self._validate_join_context(builder)
42+
43+
# Handle Join expressions directly (from JoinBuilder.on() calls)
44+
if isinstance(table, exp.Join):
45+
if builder._expression is not None and isinstance(builder._expression, exp.Select):
46+
builder._expression = builder._expression.join(table, copy=False)
47+
return cast("Self", builder)
48+
49+
table_expr = self._parse_table_expression(table, alias, builder)
50+
on_expr = self._parse_on_condition(on, builder)
51+
join_expr = self._create_join_expression(table_expr, on_expr, join_type)
52+
53+
if lateral:
54+
self._apply_lateral_modifier(join_expr)
55+
56+
if builder._expression is not None and isinstance(builder._expression, exp.Select):
57+
builder._expression = builder._expression.join(join_expr, copy=False)
58+
return cast("Self", builder)
59+
60+
def _validate_join_context(self, builder: "SQLBuilderProtocol") -> None:
61+
"""Validate that the join can be applied to the current expression."""
4162
if builder._expression is None:
4263
builder._expression = exp.Select()
4364
if not isinstance(builder._expression, exp.Select):
4465
msg = "JOIN clause is only supported for SELECT statements."
4566
raise SQLBuilderError(msg)
46-
table_expr: exp.Expression
67+
68+
def _parse_table_expression(
69+
self, table: Union[str, exp.Expression, Any], alias: Optional[str], builder: "SQLBuilderProtocol"
70+
) -> exp.Expression:
71+
"""Parse table parameter into a SQLGlot expression."""
4772
if isinstance(table, str):
48-
table_expr = parse_table_expression(table, alias)
49-
elif has_query_builder_parameters(table):
50-
if hasattr(table, "_expression") and getattr(table, "_expression", None) is not None:
51-
table_expr_value = getattr(table, "_expression", None)
52-
if table_expr_value is not None:
53-
subquery_exp = exp.paren(table_expr_value)
54-
else:
55-
subquery_exp = exp.paren(exp.Anonymous(this=""))
56-
table_expr = exp.alias_(subquery_exp, alias) if alias else subquery_exp
73+
return parse_table_expression(table, alias)
74+
if has_query_builder_parameters(table):
75+
return self._handle_query_builder_table(table, alias, builder)
76+
if isinstance(table, exp.Expression):
77+
return table
78+
return cast("exp.Expression", table)
79+
80+
def _handle_query_builder_table(
81+
self, table: Any, alias: Optional[str], builder: "SQLBuilderProtocol"
82+
) -> exp.Expression:
83+
"""Handle table parameters that are query builders."""
84+
if hasattr(table, "_expression") and getattr(table, "_expression", None) is not None:
85+
table_expr_value = getattr(table, "_expression", None)
86+
if table_expr_value is not None:
87+
subquery_exp = exp.paren(table_expr_value)
5788
else:
58-
subquery = table.build()
59-
sql_str = subquery.sql if hasattr(subquery, "sql") and not callable(subquery.sql) else str(subquery)
60-
subquery_exp = exp.paren(exp.maybe_parse(sql_str, dialect=getattr(builder, "dialect", None)))
61-
table_expr = exp.alias_(subquery_exp, alias) if alias else subquery_exp
62-
else:
63-
table_expr = table
64-
on_expr: Optional[exp.Expression] = None
65-
if on is not None:
66-
if isinstance(on, str):
67-
on_expr = exp.condition(on)
68-
elif hasattr(on, "expression") and hasattr(on, "sql"):
69-
# Handle SQL objects (from sql.raw with parameters)
70-
expression = getattr(on, "expression", None)
71-
if expression is not None and isinstance(expression, exp.Expression):
72-
# Merge parameters from SQL object into builder
73-
if hasattr(on, "parameters") and hasattr(builder, "add_parameter"):
74-
sql_parameters = getattr(on, "parameters", {})
75-
for param_name, param_value in sql_parameters.items():
76-
builder.add_parameter(param_value, name=param_name)
77-
on_expr = expression
78-
else:
79-
# If expression is None, fall back to parsing the raw SQL
80-
sql_text = getattr(on, "sql", "")
81-
# Merge parameters even when parsing raw SQL
82-
if hasattr(on, "parameters") and hasattr(builder, "add_parameter"):
83-
sql_parameters = getattr(on, "parameters", {})
84-
for param_name, param_value in sql_parameters.items():
85-
builder.add_parameter(param_value, name=param_name)
86-
on_expr = exp.maybe_parse(sql_text) or exp.condition(str(sql_text))
87-
# For other types (should be exp.Expression)
88-
elif isinstance(on, exp.Expression):
89-
on_expr = on
90-
else:
91-
# Last resort - convert to string and parse
92-
on_expr = exp.condition(str(on))
89+
subquery_exp = exp.paren(exp.Anonymous(this=""))
90+
return exp.alias_(subquery_exp, alias) if alias else subquery_exp
91+
subquery = table.build()
92+
sql_str = subquery.sql if hasattr(subquery, "sql") and not callable(subquery.sql) else str(subquery)
93+
subquery_exp = exp.paren(exp.maybe_parse(sql_str, dialect=getattr(builder, "dialect", None)))
94+
return exp.alias_(subquery_exp, alias) if alias else subquery_exp
95+
96+
def _parse_on_condition(
97+
self, on: Optional[Union[str, exp.Expression, "SQL"]], builder: "SQLBuilderProtocol"
98+
) -> Optional[exp.Expression]:
99+
"""Parse ON condition into a SQLGlot expression."""
100+
if on is None:
101+
return None
102+
103+
if isinstance(on, str):
104+
return exp.condition(on)
105+
if hasattr(on, "expression") and hasattr(on, "sql"):
106+
return self._handle_sql_object_condition(on, builder)
107+
if isinstance(on, exp.Expression):
108+
return on
109+
# Last resort - convert to string and parse
110+
return exp.condition(str(on))
111+
112+
def _handle_sql_object_condition(self, on: Any, builder: "SQLBuilderProtocol") -> exp.Expression:
113+
"""Handle SQL object conditions with parameter binding."""
114+
expression = getattr(on, "expression", None)
115+
if expression is not None and isinstance(expression, exp.Expression):
116+
# Merge parameters from SQL object into builder
117+
if hasattr(on, "parameters") and hasattr(builder, "add_parameter"):
118+
sql_parameters = getattr(on, "parameters", {})
119+
for param_name, param_value in sql_parameters.items():
120+
builder.add_parameter(param_value, name=param_name)
121+
return cast("exp.Expression", expression)
122+
# If expression is None, fall back to parsing the raw SQL
123+
sql_text = getattr(on, "sql", "")
124+
# Merge parameters even when parsing raw SQL
125+
if hasattr(on, "parameters") and hasattr(builder, "add_parameter"):
126+
sql_parameters = getattr(on, "parameters", {})
127+
for param_name, param_value in sql_parameters.items():
128+
builder.add_parameter(param_value, name=param_name)
129+
parsed_expr = exp.maybe_parse(sql_text)
130+
return parsed_expr if parsed_expr is not None else exp.condition(str(sql_text))
131+
132+
def _create_join_expression(
133+
self, table_expr: exp.Expression, on_expr: Optional[exp.Expression], join_type: str
134+
) -> exp.Join:
135+
"""Create the appropriate JOIN expression based on join type."""
93136
join_type_upper = join_type.upper()
94137
if join_type_upper == "INNER":
95-
join_expr = exp.Join(this=table_expr, on=on_expr)
96-
elif join_type_upper == "LEFT":
97-
join_expr = exp.Join(this=table_expr, on=on_expr, side="LEFT")
98-
elif join_type_upper == "RIGHT":
99-
join_expr = exp.Join(this=table_expr, on=on_expr, side="RIGHT")
100-
elif join_type_upper == "FULL":
101-
join_expr = exp.Join(this=table_expr, on=on_expr, side="FULL", kind="OUTER")
138+
return exp.Join(this=table_expr, on=on_expr)
139+
if join_type_upper == "LEFT":
140+
return exp.Join(this=table_expr, on=on_expr, side="LEFT")
141+
if join_type_upper == "RIGHT":
142+
return exp.Join(this=table_expr, on=on_expr, side="RIGHT")
143+
if join_type_upper == "FULL":
144+
return exp.Join(this=table_expr, on=on_expr, side="FULL", kind="OUTER")
145+
if join_type_upper == "CROSS":
146+
return exp.Join(this=table_expr, kind="CROSS")
147+
msg = f"Unsupported join type: {join_type}"
148+
raise SQLBuilderError(msg)
149+
150+
def _apply_lateral_modifier(self, join_expr: exp.Join) -> None:
151+
"""Apply LATERAL modifier to the join expression."""
152+
current_kind = join_expr.args.get("kind")
153+
current_side = join_expr.args.get("side")
154+
155+
if current_kind == "CROSS":
156+
join_expr.set("kind", "CROSS LATERAL")
157+
elif current_kind == "OUTER" and current_side == "FULL":
158+
join_expr.set("side", "FULL") # Keep side
159+
join_expr.set("kind", "OUTER LATERAL")
160+
elif current_side:
161+
join_expr.set("kind", f"{current_side} LATERAL")
162+
join_expr.set("side", None) # Clear side to avoid duplication
102163
else:
103-
msg = f"Unsupported join type: {join_type}"
104-
raise SQLBuilderError(msg)
105-
builder._expression = builder._expression.join(join_expr, copy=False)
106-
return cast("Self", builder)
164+
join_expr.set("kind", "LATERAL")
107165

108166
def inner_join(
109167
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression, "SQL"], alias: Optional[str] = None
@@ -154,6 +212,63 @@ def cross_join(self, table: Union[str, exp.Expression, Any], alias: Optional[str
154212
builder._expression = builder._expression.join(join_expr, copy=False)
155213
return cast("Self", builder)
156214

215+
def lateral_join(
216+
self,
217+
table: Union[str, exp.Expression, Any],
218+
on: Optional[Union[str, exp.Expression, "SQL"]] = None,
219+
alias: Optional[str] = None,
220+
) -> Self:
221+
"""Create a LATERAL JOIN.
222+
223+
Args:
224+
table: Table, subquery, or table function to join
225+
on: Optional join condition (for LATERAL JOINs with ON clause)
226+
alias: Optional alias for the joined table/subquery
227+
228+
Returns:
229+
Self for method chaining
230+
231+
Example:
232+
```python
233+
query = (
234+
sql.select("u.name", "arr.value")
235+
.from_("users u")
236+
.lateral_join("UNNEST(u.tags)", alias="arr")
237+
)
238+
```
239+
"""
240+
return self.join(table, on=on, alias=alias, join_type="INNER", lateral=True)
241+
242+
def left_lateral_join(
243+
self,
244+
table: Union[str, exp.Expression, Any],
245+
on: Optional[Union[str, exp.Expression, "SQL"]] = None,
246+
alias: Optional[str] = None,
247+
) -> Self:
248+
"""Create a LEFT LATERAL JOIN.
249+
250+
Args:
251+
table: Table, subquery, or table function to join
252+
on: Optional join condition
253+
alias: Optional alias for the joined table/subquery
254+
255+
Returns:
256+
Self for method chaining
257+
"""
258+
return self.join(table, on=on, alias=alias, join_type="LEFT", lateral=True)
259+
260+
def cross_lateral_join(self, table: Union[str, exp.Expression, Any], alias: Optional[str] = None) -> Self:
261+
"""Create a CROSS LATERAL JOIN (no ON condition).
262+
263+
Args:
264+
table: Table, subquery, or table function to join
265+
alias: Optional alias for the joined table/subquery
266+
267+
Returns:
268+
Self for method chaining
269+
"""
270+
return self.join(table, on=None, alias=alias, join_type="CROSS", lateral=True)
271+
157272

158273
@trait
159274
class JoinBuilder:
@@ -181,32 +296,19 @@ class JoinBuilder:
181296
```
182297
"""
183298

184-
def __init__(self, join_type: str) -> None:
299+
def __init__(self, join_type: str, lateral: bool = False) -> None:
185300
"""Initialize the join builder.
186301
187302
Args:
188-
join_type: Type of join (inner, left, right, full, cross)
303+
join_type: Type of join (inner, left, right, full, cross, lateral)
304+
lateral: Whether this is a LATERAL join
189305
"""
190306
self._join_type = join_type.upper()
307+
self._lateral = lateral
191308
self._table: Optional[Union[str, exp.Expression]] = None
192309
self._condition: Optional[exp.Expression] = None
193310
self._alias: Optional[str] = None
194311

195-
def __eq__(self, other: object) -> "ColumnExpression": # type: ignore[override]
196-
"""Equal to (==) - not typically used but needed for type consistency."""
197-
from sqlspec.builder._column import ColumnExpression
198-
199-
# JoinBuilder doesn't have a direct expression, so this is a placeholder
200-
# In practice, this shouldn't be called as joins are used differently
201-
placeholder_expr = exp.Literal.string(f"join_{self._join_type.lower()}")
202-
if other is None:
203-
return ColumnExpression(exp.Is(this=placeholder_expr, expression=exp.Null()))
204-
return ColumnExpression(exp.EQ(this=placeholder_expr, expression=exp.convert(other)))
205-
206-
def __hash__(self) -> int:
207-
"""Make JoinBuilder hashable."""
208-
return hash(id(self))
209-
210312
def __call__(self, table: Union[str, exp.Expression], alias: Optional[str] = None) -> Self:
211313
"""Set the table to join.
212314
@@ -254,15 +356,33 @@ def on(self, condition: Union[str, exp.Expression]) -> exp.Expression:
254356
table_expr = exp.alias_(table_expr, self._alias)
255357

256358
# Create the appropriate join type using same pattern as existing JoinClauseMixin
257-
if self._join_type == "INNER JOIN":
258-
return exp.Join(this=table_expr, on=condition_expr)
259-
if self._join_type == "LEFT JOIN":
260-
return exp.Join(this=table_expr, on=condition_expr, side="LEFT")
261-
if self._join_type == "RIGHT JOIN":
262-
return exp.Join(this=table_expr, on=condition_expr, side="RIGHT")
263-
if self._join_type == "FULL JOIN":
264-
return exp.Join(this=table_expr, on=condition_expr, side="FULL", kind="OUTER")
265-
if self._join_type == "CROSS JOIN":
359+
if self._join_type in {"INNER JOIN", "INNER", "LATERAL JOIN"}:
360+
join_expr = exp.Join(this=table_expr, on=condition_expr)
361+
elif self._join_type in {"LEFT JOIN", "LEFT"}:
362+
join_expr = exp.Join(this=table_expr, on=condition_expr, side="LEFT")
363+
elif self._join_type in {"RIGHT JOIN", "RIGHT"}:
364+
join_expr = exp.Join(this=table_expr, on=condition_expr, side="RIGHT")
365+
elif self._join_type in {"FULL JOIN", "FULL"}:
366+
join_expr = exp.Join(this=table_expr, on=condition_expr, side="FULL", kind="OUTER")
367+
elif self._join_type in {"CROSS JOIN", "CROSS"}:
266368
# CROSS JOIN doesn't use ON condition
267-
return exp.Join(this=table_expr, kind="CROSS")
268-
return exp.Join(this=table_expr, on=condition_expr)
369+
join_expr = exp.Join(this=table_expr, kind="CROSS")
370+
else:
371+
join_expr = exp.Join(this=table_expr, on=condition_expr)
372+
373+
if self._lateral or self._join_type == "LATERAL JOIN":
374+
current_kind = join_expr.args.get("kind")
375+
current_side = join_expr.args.get("side")
376+
377+
if current_kind == "CROSS":
378+
join_expr.set("kind", "CROSS LATERAL")
379+
elif current_kind == "OUTER" and current_side == "FULL":
380+
join_expr.set("side", "FULL") # Keep side
381+
join_expr.set("kind", "OUTER LATERAL")
382+
elif current_side:
383+
join_expr.set("kind", f"{current_side} LATERAL")
384+
join_expr.set("side", None) # Clear side to avoid duplication
385+
else:
386+
join_expr.set("kind", "LATERAL")
387+
388+
return join_expr

0 commit comments

Comments
 (0)