Skip to content

Commit d93f90d

Browse files
msullivanelprans
andauthored
Add a for_ operator (#887)
Following Elvis's original branch, I added it to a `gel.qb` module. I also, on Elvis's suggestion, modified `std` to re-export everything from `gel.qb`. I renamed his `foreach` to `for_` to better match EdgeQL. Fixes #729. --------- Co-authored-by: Elvis Pranskevichus <[email protected]>
1 parent 1c191fa commit d93f90d

File tree

9 files changed

+143
-10
lines changed

9 files changed

+143
-10
lines changed

gel/_internal/_codegen/_models/_pydantic.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
from gel._internal._collections_extras import ImmutableChainMap
4949
from gel._internal._namespace import ident, dunder
5050
from gel._internal._qbmodel import _abstract as _qbmodel
51+
from gel._internal._qbmodel._abstract import _syntax as _qbsyntax
52+
5153
from gel._internal._reflection._enums import SchemaPart, TypeModifier
5254
from gel._internal._schemapath import SchemaPath
5355
from gel._internal._polyfills import _strenum
@@ -2273,6 +2275,15 @@ def prepare_namespace(self, mod: IntrospectedModule) -> None:
22732275
self.py_file.update_globals(
22742276
gt.name for gt in GENERIC_TYPES
22752277
)
2278+
names = _qbsyntax.__all__
2279+
self.py_file.export(*names)
2280+
self.py_file.update_globals(names)
2281+
2282+
self.write("# Re-export top-level query-builder functions")
2283+
self.write(
2284+
"from gel._internal._qbmodel._abstract._syntax "
2285+
"import *"
2286+
)
22762287

22772288
self.py_file.update_globals(
22782289
ident(t.schemapath.name)

gel/_internal/_qb/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
SortAlias,
7676
VarAlias,
7777
is_pointer_descriptor,
78+
get_origin,
7879
)
7980

8081
from ._protocols import (
@@ -170,6 +171,7 @@
170171
"empty_set_if_none",
171172
"exprmethod",
172173
"get_object_type_splat",
174+
"get_origin",
173175
"is_expr_compatible",
174176
"is_pointer_descriptor",
175177
"toplevel_edgeql",

gel/_internal/_qb/_expressions.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def __edgeql_expr__(self, *, ctx: ScopeContext) -> str:
457457
rexpr=item,
458458
type_=SchemaPath("std", "bool"),
459459
)
460-
return f"FILTER {edgeql(fexpr, ctx=ctx)}"
460+
return f"FILTER {edgeql_exprstmt(fexpr, ctx=ctx)}"
461461

462462

463463
class OrderDirection(_strenum.StrEnum):
@@ -529,7 +529,7 @@ def __edgeql_expr__(self, *, ctx: ScopeContext) -> str:
529529
type_=SchemaPath("std", "bool"),
530530
)
531531

532-
return f"ORDER BY {edgeql(dexpr, ctx=ctx)}"
532+
return f"ORDER BY {edgeql_exprstmt(dexpr, ctx=ctx)}"
533533

534534

535535
@dataclass(kw_only=True, frozen=True)
@@ -548,7 +548,7 @@ def __edgeql_expr__(self, *, ctx: ScopeContext) -> str:
548548
if isinstance(self.limit, IntLiteral) and self.limit.val == 1:
549549
return "LIMIT 1"
550550
else:
551-
clause = edgeql(self.limit, ctx=ctx)
551+
clause = edgeql_exprstmt(self.limit, ctx=ctx)
552552
return f"LIMIT {clause}"
553553

554554

@@ -564,7 +564,7 @@ def precedence(self) -> _edgeql.Precedence:
564564
return _edgeql.PRECEDENCE[_edgeql.Token.OFFSET]
565565

566566
def __edgeql_expr__(self, *, ctx: ScopeContext) -> str:
567-
return f"OFFSET {edgeql(self.offset, ctx=ctx)}"
567+
return f"OFFSET {edgeql_exprstmt(self.offset, ctx=ctx)}"
568568

569569

570570
@dataclass(kw_only=True, frozen=True)
@@ -726,7 +726,7 @@ def _body_edgeql(self, ctx: ScopeContext) -> str:
726726

727727

728728
@dataclass(kw_only=True, frozen=True)
729-
class ForStmt(IteratorExpr):
729+
class ForStmt(IteratorExpr, Stmt):
730730
stmt: _edgeql.Token = _edgeql.Token.FOR
731731
iter_expr: Expr
732732
body: Expr
@@ -745,12 +745,19 @@ def subnodes(self) -> Iterable[Node]:
745745
return (self.iter_expr, self.body)
746746

747747
def _edgeql(self, ctx: ScopeContext) -> str:
748+
ctx.bind(self.var)
748749
return (
749750
f"{self.stmt} {edgeql(self.var, ctx=ctx)} IN "
750751
f"({edgeql(self.iter_expr, ctx=ctx)})\n"
751752
f"UNION ({edgeql(self.body, ctx=ctx)})"
752753
)
753754

755+
def _iteration_edgeql(self, ctx: ScopeContext) -> str:
756+
raise AssertionError('...')
757+
758+
def _body_edgeql(self, ctx: ScopeContext) -> str:
759+
raise AssertionError('...')
760+
754761

755762
class Splat(_strenum.StrEnum):
756763
STAR = "*"
@@ -908,6 +915,17 @@ def get_object_type_splat(cls: type[GelTypeMetadata]) -> Shape:
908915
return shape
909916

910917

918+
def edgeql_exprstmt(
919+
source: ExprCompatible,
920+
*,
921+
ctx: ScopeContext,
922+
) -> str:
923+
res = edgeql(source, ctx=ctx)
924+
if isinstance(source, Stmt):
925+
res = f'({res})'
926+
return res
927+
928+
911929
def toplevel_edgeql(
912930
x: ExprCompatible,
913931
*,

gel/_internal/_qb/_generics.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
TYPE_CHECKING,
1111
Any,
1212
TypedDict,
13+
TypeVar,
1314
NoReturn,
1415
get_args,
1516
)
@@ -315,9 +316,6 @@ def __infix_op__(
315316
*,
316317
swapped: bool = False,
317318
) -> Any:
318-
if op == "__eq__" and operand is self:
319-
return True
320-
321319
# Check for None comparison and raise appropriate error
322320
if operand is None and op in {"__eq__", "__ne__"}:
323321
_raise_op_error(_Op.IS_NONE if op == "__eq__" else _Op.IS_NOT_NONE)
@@ -375,6 +373,17 @@ def __edgeql__(self) -> tuple[type, tuple[str, dict[str, object]]]:
375373
return type_, toplevel_edgeql(self, splat_cb=splat_cb)
376374

377375

376+
_T = TypeVar("_T")
377+
378+
379+
def get_origin(x: type[_T] | BaseAlias) -> type[_T]:
380+
return (
381+
x.__gel_origin__
382+
if isinstance(x, BaseAlias) # type: ignore[return-value]
383+
else x
384+
)
385+
386+
378387
class PathAlias(BaseAlias):
379388
pass
380389

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# SPDX-PackageName: gel-python
2+
# SPDX-License-Identifier: Apache-2.0
3+
# SPDX-FileCopyrightText: Copyright Gel Data Inc. and the contributors.
4+
5+
6+
"""Query building constructs"""
7+
8+
from typing import TypeVar
9+
from collections.abc import Callable
10+
11+
from gel._internal import _qb
12+
from gel._internal._qbmodel import _abstract
13+
14+
15+
_T = TypeVar("_T", bound=_abstract.GelType)
16+
_X = TypeVar("_X", bound=_abstract.GelType)
17+
18+
19+
def for_(iterator: type[_T], body: Callable[[type[_T]], type[_X]]) -> type[_X]:
20+
"""Evaluate the expression returned by *body* for each element in *iter*.
21+
22+
This is the Pythonic representation of the EdgeQL FOR expression."""
23+
24+
iter_expr = _qb.edgeql_qb_expr(iterator)
25+
scope = _qb.Scope()
26+
var = _qb.Variable(type_=iter_expr.type, scope=scope)
27+
t = _qb.get_origin(iterator)
28+
body_ = body(_qb.AnnotatedVar(t, var)) # type: ignore [arg-type]
29+
return _qb.AnnotatedExpr( # type: ignore [return-value]
30+
_qb.get_origin(body_),
31+
_qb.ForStmt(
32+
iter_expr=iter_expr,
33+
body=_qb.edgeql_qb_expr(body_),
34+
scope=scope,
35+
),
36+
)
37+
38+
39+
__all__ = ("for_",)

gel/_internal/ruff.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ extend-select = [
2424
"PIE", # flake8-pie
2525
"PL", # pylint
2626
"PYI", # flake8-pyi
27-
"Q", # flake8-quotes
2827
"RUF", # ruff specific
2928
"S", # flake8-bandit
3029
"SIM", # flake8-simplify

gel/models/ruff.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ extend-select = [
2424
"PIE", # flake8-pie
2525
"PL", # pylint
2626
"PYI", # flake8-pyi
27-
"Q", # flake8-quotes
2827
"RUF", # ruff specific
2928
"S", # flake8-bandit
3029
"SIM", # flake8-simplify

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ extend-ignore = [
221221
"E402", # module-import-not-at-top-of-file
222222
"E252", # missing-whitespace-around-parameter-equals
223223
"F541", # f-string-missing-placeholders
224+
"Q000", # prefer double quotes
224225
]
225226

226227
[tool.ruff.format]

tests/test_qb.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,61 @@ def test_qb_enum_01(self):
632632
self.assertEqual(e.color, default.Color.Red)
633633
self.assertEqual(e.name, "red")
634634

635+
def test_qb_for_01(self):
636+
from models.orm import default, std
637+
638+
res = self.client.query(
639+
std.for_(
640+
std.range_unpack(std.range(std.int64(1), std.int64(10))),
641+
lambda x: x * 2,
642+
)
643+
)
644+
self.assertEqual(set(res), {i * 2 for i in range(1, 10)})
645+
646+
res = self.client.query(
647+
std.for_(
648+
std.range_unpack(std.range(std.int64(1), std.int64(3))),
649+
lambda x: std.for_(
650+
std.range_unpack(std.range(std.int64(1), std.int64(3))),
651+
lambda y: x * 10 + y,
652+
)
653+
)
654+
)
655+
656+
self.assertEqual(set(res), {11, 12, 21, 22})
657+
658+
res2 = self.client.query(
659+
std.for_(
660+
default.User,
661+
lambda x: x.name,
662+
)
663+
)
664+
self.assertEqual(
665+
set(res2),
666+
{'Alice', 'Zoe', 'Billie', 'Dana', 'Cameron', 'Elsa'},
667+
)
668+
669+
res3 = self.client.query(
670+
default.User.filter(
671+
lambda u: std.for_(
672+
std.assert_exists(std.int64(0)),
673+
# HMMMMM
674+
lambda x: x == x,
675+
)
676+
)
677+
)
678+
self.assertEqual(len(res3), 6)
679+
680+
res4 = self.client.query(
681+
default.User.filter(
682+
lambda u: std.for_(
683+
u.name,
684+
lambda x: x == "Alice",
685+
)
686+
)
687+
)
688+
self.assertEqual(len(res4), 1)
689+
635690

636691
class TestQueryBuilderModify(tb.ModelTestCase):
637692
"""This test suite is for data manipulation using QB."""

0 commit comments

Comments
 (0)