Skip to content

Commit 70098bc

Browse files
committed
Add get_origin
1 parent 5bf52c8 commit 70098bc

File tree

5 files changed

+29
-4
lines changed

5 files changed

+29
-4
lines changed

gel/_internal/_qb/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
SortAlias,
7676
VarAlias,
7777
is_pointer_descriptor,
78+
get_origin,
7879
)
7980

8081
from ._protocols import (
@@ -170,6 +171,7 @@
170171
"empty_set_if_none",
171172
"exprmethod",
172173
"get_object_type_splat",
174+
"get_origin",
173175
"is_expr_compatible",
174176
"is_pointer_descriptor",
175177
"toplevel_edgeql",

gel/_internal/_qb/_generics.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
TYPE_CHECKING,
1111
Any,
1212
TypedDict,
13+
TypeVar,
1314
NoReturn,
1415
get_args,
1516
)
@@ -375,6 +376,17 @@ def __edgeql__(self) -> tuple[type, tuple[str, dict[str, object]]]:
375376
return type_, toplevel_edgeql(self, splat_cb=splat_cb)
376377

377378

379+
_T = TypeVar("_T")
380+
381+
382+
def get_origin(x: type[_T] | BaseAlias) -> type[_T]:
383+
return (
384+
x.__gel_origin__
385+
if isinstance(x, BaseAlias) # type: ignore[return-value]
386+
else x
387+
)
388+
389+
378390
class PathAlias(BaseAlias):
379391
pass
380392

gel/_internal/_qbmodel/_abstract/_syntax.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ def for_(iterator: type[_T], body: Callable[[type[_T]], type[_X]]) -> type[_X]:
2424
iter_expr = _qb.edgeql_qb_expr(iterator)
2525
scope = _qb.Scope()
2626
var = _qb.Variable(type_=iter_expr.type, scope=scope)
27-
body_ = body(_qb.AnnotatedVar(iterator.__gel_origin__, var)) # type: ignore [arg-type, attr-defined]
27+
t = _qb.get_origin(iterator)
28+
body_ = body(_qb.AnnotatedVar(t, var)) # type: ignore [arg-type]
2829
return _qb.AnnotatedExpr( # type: ignore [return-value]
29-
body_.__gel_origin__, # type: ignore [attr-defined]
30+
_qb.get_origin(body_),
3031
_qb.ForStmt(
3132
iter_expr=iter_expr,
3233
body=_qb.edgeql_qb_expr(body_),

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ module = [
126126
"gel._internal._qbmodel.*",
127127
"gel._internal._qbmodel._abstract.*",
128128
"gel._internal._qbmodel._pydantic.*",
129-
"gel.qb",
130129
"tests.test_dataclass_extras",
131130
"tests.test_dis_bool",
132131
"tests.test_is_overload",

tests/test_qb.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,7 @@ def test_qb_enum_01(self):
595595
self.assertEqual(e.name, "red")
596596

597597
def test_qb_for_01(self):
598-
from models.orm import std
598+
from models.orm import default, std
599599

600600
res = self.client.query(
601601
std.for_(
@@ -617,6 +617,17 @@ def test_qb_for_01(self):
617617

618618
self.assertEqual(set(res), {11, 12, 21, 22})
619619

620+
res2 = self.client.query(
621+
std.for_(
622+
default.User,
623+
lambda x: x.name,
624+
)
625+
)
626+
self.assertEqual(
627+
set(res2),
628+
{'Alice', 'Zoe', 'Billie', 'Dana', 'Cameron', 'Elsa'},
629+
)
630+
620631

621632
class TestQueryBuilderModify(tb.ModelTestCase):
622633
"""This test suite is for data manipulation using QB."""

0 commit comments

Comments
 (0)