diff --git a/gel/_internal/_codegen/_models/_pydantic.py b/gel/_internal/_codegen/_models/_pydantic.py index de417203..58a69891 100644 --- a/gel/_internal/_codegen/_models/_pydantic.py +++ b/gel/_internal/_codegen/_models/_pydantic.py @@ -435,6 +435,12 @@ class ModuleAspect(enum.Enum): LATE = enum.auto() +@dataclasses.dataclass(frozen=True, kw_only=True) +class Backlink: + source: reflection.ObjectType + pointer: reflection.Pointer + + class SchemaGenerator: def __init__( self, @@ -451,6 +457,9 @@ def __init__( self._std_modules: list[SchemaPath] = [] self._types: Mapping[str, reflection.Type] = {} self._casts: reflection.CastMatrix + self._backlinks: dict[ + reflection.ObjectType, dict[str, list[Backlink]] + ] = {} self._operators: reflection.OperatorMatrix self._functions: list[reflection.Function] self._globals: list[reflection.Global] @@ -579,6 +588,7 @@ def run(self, outdir: pathlib.Path) -> tuple[Schema, set[pathlib.Path]]: all_casts=self._casts, all_operators=self._operators, all_globals=self._globals, + all_backlinks=self._backlinks, modules=self._modules, schema_part=self._schema_part, ) @@ -604,6 +614,7 @@ def run(self, outdir: pathlib.Path) -> tuple[Schema, set[pathlib.Path]]: all_casts=self._casts, all_operators=self._operators, all_globals=self._globals, + all_backlinks=self._backlinks, modules=all_modules, schema_part=self._schema_part, ) @@ -659,6 +670,24 @@ def introspect_schema(self) -> Schema: if reflection.is_object_type(t): name = t.schemapath self._modules[name.parent]["object_types"][name.name] = t + + for p in t.pointers: + if ( + reflection.is_link(p) + # For now don't include std::BaseObject.__type__ + # Users should just select on the appropriate type + and p.name != "__type__" + ): + target = self._types[p.target_id] + assert isinstance(target, reflection.ObjectType) + if target not in self._backlinks: + self._backlinks[target] = {} + if p.name not in self._backlinks[target]: + self._backlinks[target][p.name] = [] + self._backlinks[target][p.name].append( + Backlink(source=t, pointer=p) + ) + elif reflection.is_scalar_type(t): name = t.schemapath self._modules[name.parent]["scalar_types"][name.name] = t @@ -694,6 +723,7 @@ def _generate_common_types( all_casts=self._casts, all_operators=self._operators, all_globals=self._globals, + all_backlinks=self._backlinks, modules=self._modules, schema_part=self._schema_part, ) @@ -944,6 +974,9 @@ def __init__( all_casts: reflection.CastMatrix, all_operators: reflection.OperatorMatrix, all_globals: list[reflection.Global], + all_backlinks: Mapping[ + reflection.ObjectType, Mapping[str, Sequence[Backlink]] + ], modules: Collection[SchemaPath], schema_part: reflection.SchemaPart, ) -> None: @@ -954,6 +987,7 @@ def __init__( self._casts = all_casts self._operators = all_operators self._globals = all_globals + self._backlinks = all_backlinks schema_obj_type = None for t in all_types.values(): self._types_by_name[t.name] = t @@ -2340,7 +2374,7 @@ def write_generic_types( unpack = self.import_name("typing_extensions", "Unpack") geltype = self.import_name(BASE_IMPL, "GelType") geltypemeta = self.import_name(BASE_IMPL, "GelTypeMeta") - gelmodel = self.import_name(BASE_IMPL, "GelModel") + gelobjectmodel = self.import_name(BASE_IMPL, "GelObjectModel") gelmodelmeta = self.import_name(BASE_IMPL, "GelModelMeta") anytuple = self.import_name(BASE_IMPL, "AnyTuple") anynamedtuple = self.import_name(BASE_IMPL, "AnyNamedTuple") @@ -2367,7 +2401,7 @@ def write_generic_types( geltype, ], SchemaPath("std", "anyobject"): [ - gelmodel, + gelobjectmodel, "anytype", ], SchemaPath("std", "anytuple"): [ @@ -4086,6 +4120,9 @@ def _mangle_default_shape(name: str) -> str: if proplinks: self.write_object_type_link_models(objtype) + if objtype.name != "std::FreeObject": + self._write_object_backlinks(objtype) + anyobject_meta = self.get_object( SchemaPath("std", "__anyobject_meta__"), aspect=ModuleAspect.SHAPES, @@ -4213,6 +4250,16 @@ def write_id_computed( self.write(f"__gel_type_class__ = __{name}_ops__") if objtype.name == "std::BaseObject": write_id_attr(objtype, "RequiredId") + + if objtype.name != "std::FreeObject": + backlinks_model_name = self._mangle_backlinks_model_name(name) + g_oblm_desc = self.import_name( + BASE_IMPL, "GelObjectBacklinksModelDescriptor" + ) + oblm_desc = f"{g_oblm_desc}[{backlinks_model_name}]" + self.write(f"__backlinks__: {oblm_desc} = {oblm_desc}()") + self.write() + self._write_base_object_type_body(objtype, include_tname=True) with self.type_checking(): self._write_object_type_qb_methods(objtype) @@ -4333,6 +4380,208 @@ def write_id_computed( self.write() + @staticmethod + def _mangle_backlinks_model_name(name: str) -> str: + return f"__{name}_backlinks__" + + def _write_object_backlinks( + self, + objtype: reflection.ObjectType, + ) -> None: + type_name = objtype.schemapath + name = type_name.name + + schema_path = self.import_name(BASE_IMPL, "SchemaPath") + parametric_type_name = self.import_name( + BASE_IMPL, "ParametricTypeName" + ) + computed_multi_link = self.import_name(BASE_IMPL, "ComputedMultiLink") + std_base_object_t = self.get_object( + SchemaPath.from_segments("std", "BaseObject"), + aspect=ModuleAspect.SHAPES, + ) + + objtype_bases = [ + base_type + for base_ref in objtype.bases + if (base_type := self._types.get(base_ref.id, None)) + if isinstance(base_type, reflection.ObjectType) + ] + objtype_name = type_name.as_python_code( + schema_path, parametric_type_name + ) + + backlinks_model_name = self._mangle_backlinks_model_name(name) + + backlinks_class_bases: list[str] + backlinks_reflection_class_bases: list[str] + + # Backlinks' reflections' pointers combine the current types + # backlinks with those of their base types. + # + # Eg. With `type A` and `type B extending A`, + # - __A_backlinks__.__gel_reflection__.pointers will contain + # all backlinks to A (and BaseObject) + # - __B_backlinks__.__gel_reflection__.pointers will contain + # all backlinks to B, backlinks to A. + # + # No base type backlinks are needed for BaseObject. + backlinks_reflection_pointer_bases: list[str] + + if objtype.name == "std::BaseObject": + # __BaseObject_backlinks__ derives from GelObjectBacklinksModel + # while all other backlinks derive from it directly on indirectly. + object_backlinks_model = self.import_name( + BASE_IMPL, "GelObjectBacklinksModel" + ) + backlinks_class_bases = [object_backlinks_model] + backlinks_reflection_class_bases = [ + f"{object_backlinks_model}.__gel_reflection__" + ] + backlinks_reflection_pointer_bases = [] + else: + backlinks_class_bases = [ + self.get_object( + SchemaPath( + base_type.schemapath.parent, + self._mangle_backlinks_model_name( + base_type.schemapath.name + ), + ), + aspect=ModuleAspect.SHAPES, + ) + for base_type in objtype_bases + ] + backlinks_reflection_class_bases = [ + f"{bbt}.__gel_reflection__" for bbt in backlinks_class_bases + ] + backlinks_reflection_pointer_bases = backlinks_class_bases + + object_backlinks = self._backlinks.get(objtype, {}) + + # The backlinks model class + with self._class_def(backlinks_model_name, backlinks_class_bases): + with self._class_def( + "__gel_reflection__", backlinks_reflection_class_bases + ): + self.write(f"name = {objtype_name}") + self.write(f"type_name = {objtype_name}") + self._write_backlinks_pointers_reflection( + object_backlinks, backlinks_reflection_pointer_bases + ) + + for backlink_name in object_backlinks: + backlink_t = f"{computed_multi_link}[{std_base_object_t}]" + self.write(f"{backlink_name}: {backlink_t}") + + self.export(backlinks_model_name) + + self.write() + + def _write_backlinks_pointers_reflection( + self, + object_backlinks: Mapping[str, Sequence[Backlink]], + backlinks_reflection_pointer_bases: Sequence[str], + ) -> None: + dict_ = self.import_name( + "builtins", "dict", import_time=ImportTime.typecheck + ) + str_ = self.import_name( + "builtins", "str", import_time=ImportTime.typecheck + ) + gel_ptr_ref = self.import_name( + BASE_IMPL, + "GelPointerReflection", + import_time=ImportTime.runtime + if object_backlinks + else ImportTime.typecheck, + ) + lazyclassproperty = self.import_name(BASE_IMPL, "LazyClassProperty") + ptr_ref_t = f"{dict_}[{str_}, {gel_ptr_ref}]" + with self._classmethod_def( + "pointers", + [], + ptr_ref_t, + decorators=(f'{lazyclassproperty}["{ptr_ref_t}"]',), + ): + if object_backlinks: + self.write(f"my_ptrs: {ptr_ref_t} = {{") + classes = { + "SchemaPath": self.import_name(BASE_IMPL, "SchemaPath"), + "ParametricTypeName": self.import_name( + BASE_IMPL, "ParametricTypeName" + ), + "GelPointerReflection": gel_ptr_ref, + "Cardinality": self.import_name(BASE_IMPL, "Cardinality"), + "PointerKind": self.import_name(BASE_IMPL, "PointerKind"), + "StdBaseObject": self.get_object( + SchemaPath.from_segments("std", "BaseObject"), + aspect=ModuleAspect.SHAPES, + ), + } + with self.indented(): + for ( + backlink_name, + backlink_values, + ) in object_backlinks.items(): + r = self._reflect_backlink( + backlink_name, backlink_values, classes + ) + self.write(f"{backlink_name!r}: {r},") + self.write("}") + else: + self.write(f"my_ptrs: {ptr_ref_t} = {{}}") + + if backlinks_reflection_pointer_bases: + pp = "__gel_reflection__.pointers" + ret = self.format_list( + "return ({list})", + [ + "my_ptrs", + *_map_name( + lambda s: f"{s}.{pp}", + backlinks_reflection_pointer_bases, + ), + ], + separator=" | ", + carry_separator=True, + ) + else: + ret = "return my_ptrs" + + self.write(ret) + + self.write() + + def _reflect_backlink( + self, + name: str, + backlinks: Sequence[Backlink], + classes: dict[str, str], + ) -> str: + kwargs: dict[str, str] = { + "name": repr(name), + "type": classes["StdBaseObject"], + "kind": ( + f"{classes['PointerKind']}({str(reflection.PointerKind.Link)!r})" + ), + "cardinality": ( + f"{classes['Cardinality']}({str(reflection.Cardinality.Many)!r})" + ), + "computed": "True", + "readonly": "True", + "has_default": "False", + "mutable": "False", + } + + # For now don't get any back link props + kwargs["properties"] = "None" + + return self.format_list( + f"{classes['GelPointerReflection']}({{list}})", + [f"{k}={v}" for k, v in kwargs.items()], + ) + @contextlib.contextmanager def _object_type_variant( self, @@ -4354,8 +4603,8 @@ def _object_type_variant( ) if not list(variant_bases): - gel_model = self.import_name(BASE_IMPL, "GelModel") - bases.append(gel_model) + gel_object_model = self.import_name(BASE_IMPL, "GelObjectModel") + bases.append(gel_object_model) with self._class_def( variant, diff --git a/gel/_internal/_qb/_abstract.py b/gel/_internal/_qb/_abstract.py index 749e1882..2c25ce6c 100644 --- a/gel/_internal/_qb/_abstract.py +++ b/gel/_internal/_qb/_abstract.py @@ -156,6 +156,7 @@ class PathExpr(AtomicExpr): name: str is_lprop: bool = False is_link: bool = False + is_backlink: bool = False def subnodes(self) -> Iterable[Node]: return (self.source,) diff --git a/gel/_internal/_qb/_expressions.py b/gel/_internal/_qb/_expressions.py index 23801473..6ba61f25 100644 --- a/gel/_internal/_qb/_expressions.py +++ b/gel/_internal/_qb/_expressions.py @@ -217,6 +217,8 @@ def __edgeql_expr__(self, *, ctx: ScopeContext) -> str: source = current.source if isinstance(source, PathPrefix) and source.lprop_pivot: step = f"@{_edgeql.quote_ident(current.name)}" + elif current.is_backlink: + step = f".<{_edgeql.quote_ident(current.name)}" else: step = f".{_edgeql.quote_ident(current.name)}" steps.append(step) diff --git a/gel/_internal/_qbmodel/_abstract/__init__.py b/gel/_internal/_qbmodel/_abstract/__init__.py index 6d54bc7d..41fb17a8 100644 --- a/gel/_internal/_qbmodel/_abstract/__init__.py +++ b/gel/_internal/_qbmodel/_abstract/__init__.py @@ -11,6 +11,7 @@ AbstractGelLinkModel, AbstractGelModel, AbstractGelModelMeta, + AbstractGelObjectBacklinksModel, AbstractGelSourceModel, DefaultValue, GelType, @@ -19,6 +20,7 @@ ) from ._descriptors import ( + AbstractGelObjectModel, AbstractGelProxyModel, AnyLinkDescriptor, AnyPropertyDescriptor, @@ -27,6 +29,7 @@ ComputedMultiPropertyDescriptor, ComputedPropertyDescriptor, GelLinkModelDescriptor, + GelObjectBacklinksModelDescriptor, LinkDescriptor, ModelFieldDescriptor, MultiLinkDescriptor, @@ -118,6 +121,8 @@ "AbstractGelLinkModel", "AbstractGelModel", "AbstractGelModelMeta", + "AbstractGelObjectBacklinksModel", + "AbstractGelObjectModel", "AbstractGelProxyModel", "AbstractGelSourceModel", "AbstractLinkSet", @@ -144,6 +149,7 @@ "DateTimeLike", "DefaultValue", "GelLinkModelDescriptor", + "GelObjectBacklinksModelDescriptor", "GelPrimitiveType", "GelScalarType", "GelType", diff --git a/gel/_internal/_qbmodel/_abstract/_base.py b/gel/_internal/_qbmodel/_abstract/_base.py index b64a35c2..e5e50c2a 100644 --- a/gel/_internal/_qbmodel/_abstract/_base.py +++ b/gel/_internal/_qbmodel/_abstract/_base.py @@ -227,6 +227,27 @@ def __edgeql__(self) -> tuple[type, str]: ) +class AbstractGelObjectBacklinksModel( + AbstractGelSourceModel, + _qb.GelTypeMetadata, +): + if TYPE_CHECKING: + # Whether the model was copied by reference and must + # be copied by value before being accessed by the user. + __gel_copied_by_ref__: bool + + class __gel_reflection__( # noqa: N801 + _qb.GelSourceMetadata.__gel_reflection__, + _qb.GelTypeMetadata.__gel_reflection__, + ): + pass + + @classmethod + def __edgeql_qb_expr__(cls) -> _qb.Expr: # pyright: ignore [reportIncompatibleMethodOverride] + this_type = cls.__gel_reflection__.type_name + return _qb.SchemaSet(type_=this_type) + + class AbstractGelLinkModel(AbstractGelSourceModel): if TYPE_CHECKING: # Whether the model was copied by reference and must diff --git a/gel/_internal/_qbmodel/_abstract/_descriptors.py b/gel/_internal/_qbmodel/_abstract/_descriptors.py index 5a5dc55c..f1821c81 100644 --- a/gel/_internal/_qbmodel/_abstract/_descriptors.py +++ b/gel/_internal/_qbmodel/_abstract/_descriptors.py @@ -32,8 +32,9 @@ from ._base import ( GelType, - AbstractGelModel, AbstractGelLinkModel, + AbstractGelModel, + AbstractGelObjectBacklinksModel, is_gel_type, maybe_collapse_object_type_variant_union, LITERAL_TAG_FIELDS, @@ -170,6 +171,7 @@ def get( return _UNRESOLVED_TYPE else: source: _qb.Expr + is_backlink = issubclass(owner, AbstractGelObjectBacklinksModel) if expr is not None: source = expr.__gel_metadata__ elif _qb.is_expr_compatible(owner): @@ -188,6 +190,7 @@ def get( source=source, name=self.__gel_name__, is_lprop=False, + is_backlink=is_backlink, ) return _qb.AnnotatedPath(t, metadata) @@ -896,3 +899,70 @@ def proxy_link( new=proxy_type.__gel_validate__(new), proxy_type=proxy_type, ) + + +_OBMT_co = TypeVar( + "_OBMT_co", bound=AbstractGelObjectBacklinksModel, covariant=True +) +"""Derived model type""" + + +class AbstractGelObjectModel(AbstractGelModel): + __backlinks__: GelObjectBacklinksModelDescriptor[ + AbstractGelObjectBacklinksModel + ] + + +class GelObjectBacklinksModelDescriptor( + _typing_parametric.PickleableClassParametricType, + _qb.AbstractFieldDescriptor, + Generic[_OBMT_co], +): + _backlinks_model_class: ClassVar[type[_OBMT_co]] # type: ignore [misc] + + def __set_name__(self, owner: type[Any], name: str) -> None: + self._backlinks_model_attr = name + + @overload + def __get__( + self, instance: None, owner: type[Any], / + ) -> type[_OBMT_co]: ... + + @overload + def __get__( + self, instance: Any, owner: type[Any] | None = None, / + ) -> _OBMT_co: ... + + def __get__( + self, + instance: Any | None, + owner: type[Any] | None = None, + /, + ) -> type[_OBMT_co] | _OBMT_co: + if instance is None: + return self._backlinks_model_class + else: + attr = self._backlinks_model_attr + backlinks: _OBMT_co | None = instance.__dict__.get(attr) + if backlinks is None: + backlinks = ( + self._backlinks_model_class.__gel_model_construct__({}) + ) + instance.__dict__[attr] = backlinks + + return backlinks + + def get( + self, + owner: type[AbstractGelObjectModel], + expr: _qb.BaseAlias | None = None, + ) -> Any: + source: _qb.Expr + if expr is not None: + source = expr.__gel_metadata__ + elif _qb.is_expr_compatible(owner): + source = _qb.edgeql_qb_expr(owner) + else: + raise AssertionError("missing source for backlink path") + + return _qb.AnnotatedExpr(owner.__backlinks__, source) # pyright: ignore [reportGeneralTypeIssues] diff --git a/gel/_internal/_qbmodel/_abstract/_methods.py b/gel/_internal/_qbmodel/_abstract/_methods.py index 621beb45..bcba7d83 100644 --- a/gel/_internal/_qbmodel/_abstract/_methods.py +++ b/gel/_internal/_qbmodel/_abstract/_methods.py @@ -19,12 +19,18 @@ from gel._internal import _qb from gel._internal._schemapath import ( TypeNameIntersection, + TypeNameExpr, ) from gel._internal import _type_expression from gel._internal._xmethod import classonlymethod -from ._base import AbstractGelModel +from ._base import AbstractGelModel, AbstractGelObjectBacklinksModel +from ._descriptors import ( + GelObjectBacklinksModelDescriptor, + ModelFieldDescriptor, + field_descriptor, +) from ._expressions import ( add_filter, add_limit, @@ -241,8 +247,8 @@ def __edgeql_qb_expr__(cls) -> _qb.Expr: # pyright: ignore [reportIncompatibleM return _qb.SchemaSet(type_=this_type) -_T_Lhs = TypeVar("_T_Lhs", bound="type[AbstractGelModel]") -_T_Rhs = TypeVar("_T_Rhs", bound="type[AbstractGelModel]") +_T_Lhs = TypeVar("_T_Lhs", bound="AbstractGelModel") +_T_Rhs = TypeVar("_T_Rhs", bound="AbstractGelModel") class BaseGelModelIntersection( @@ -256,6 +262,14 @@ class BaseGelModelIntersection( rhs: ClassVar[type[AbstractGelModel]] +class BaseGelModelIntersectionBacklinks( + AbstractGelObjectBacklinksModel, + _type_expression.Intersection, +): + lhs: ClassVar[type[AbstractGelObjectBacklinksModel]] + rhs: ClassVar[type[AbstractGelObjectBacklinksModel]] + + T = TypeVar('T') U = TypeVar('U') @@ -308,18 +322,14 @@ def combine_dicts( type[AbstractGelModel], weakref.WeakKeyDictionary[ type[AbstractGelModel], - type[ - BaseGelModelIntersection[ - type[AbstractGelModel], type[AbstractGelModel] - ] - ], + type[BaseGelModelIntersection[AbstractGelModel, AbstractGelModel]], ], ] = weakref.WeakKeyDictionary() def create_intersection( - lhs: _T_Lhs, - rhs: _T_Rhs, + lhs: type[_T_Lhs], + rhs: type[_T_Rhs], ) -> type[BaseGelModelIntersection[_T_Lhs, _T_Rhs]]: """Create a runtime intersection type which acts like a GelModel.""" @@ -347,7 +357,6 @@ class __gel_reflection__(_qb.GelObjectTypeExprMetadata.__gel_reflection__): # n rhs.__gel_reflection__.type_name, ) ) - pointers = ptr_reflections @classmethod @@ -358,6 +367,7 @@ def object( "Type expressions schema objects are inaccessible" ) + # Create the resulting intersection type result = type( f"({lhs.__name__} & {rhs.__name__})", (BaseGelModelIntersection,), @@ -365,45 +375,18 @@ def object( 'lhs': lhs, 'rhs': rhs, '__gel_reflection__': __gel_reflection__, + "__gel_proxied_dunders__": frozenset( + { + "__backlinks__", + } + ), }, ) - # Generate path aliases for pointers. - # - # These are used to generate the appropriate path prefix when getting - # pointers in shapes. - # - # For example, doing `Foo.select(foo=lambda x: x.is_(Bar).bar)` - # will produce the query: - # select Foo { [is Bar].bar } - lhs_prefix = _qb.PathTypeIntersectionPrefix( - type_=__gel_reflection__.type_name, - type_filter=lhs.__gel_reflection__.type_name, - ) - rhs_prefix = _qb.PathTypeIntersectionPrefix( - type_=__gel_reflection__.type_name, - type_filter=rhs.__gel_reflection__.type_name, - ) - - def process_path_alias( - p_name: str, - p_refl: _qb.GelPointerReflection, - path_alias: _qb.PathAlias, - source: _qb.Expr, - ) -> _qb.PathAlias: - return _qb.PathAlias( - path_alias.__gel_origin__, - _qb.Path( - type_=p_refl.type, - source=source, - name=p_name, - is_lprop=False, - ), - ) - - path_aliases: dict[str, _qb.PathAlias] = combine_dicts( + # Generate field descriptors. + descriptors: dict[str, ModelFieldDescriptor] = combine_dicts( { - p_name: process_path_alias(p_name, p_refl, path_alias, lhs_prefix) + p_name: field_descriptor(result, p_name, path_alias.__gel_origin__) for p_name, p_refl in lhs.__gel_reflection__.pointers.items() if ( hasattr(lhs, p_name) @@ -412,7 +395,7 @@ def process_path_alias( ) }, { - p_name: process_path_alias(p_name, p_refl, path_alias, rhs_prefix) + p_name: field_descriptor(result, p_name, path_alias.__gel_origin__) for p_name, p_refl in rhs.__gel_reflection__.pointers.items() if ( hasattr(rhs, p_name) @@ -421,11 +404,99 @@ def process_path_alias( ) }, ) - for p_name, path_alias in path_aliases.items(): - setattr(result, p_name, path_alias) + for p_name, descriptor in descriptors.items(): + setattr(result, p_name, descriptor) + + # Generate backlinks if required (they should generally be) + if (lhs_backlinks := getattr(lhs, "__backlinks__", None)) and ( + rhs_backlinks := getattr(rhs, "__backlinks__", None) + ): + backlinks_model = create_intersection_backlinks( + lhs_backlinks, + rhs_backlinks, + result, + __gel_reflection__.type_name, + ) + setattr( # noqa: B010 + result, + "__backlinks__", + GelObjectBacklinksModelDescriptor[backlinks_model](), # type: ignore [valid-type] + ) if lhs not in _type_intersection_cache: _type_intersection_cache[lhs] = weakref.WeakKeyDictionary() _type_intersection_cache[lhs][rhs] = result return result + + +def _order_base_types(lhs: type, rhs: type) -> tuple[type, ...]: + if lhs == rhs: + return (lhs,) + elif issubclass(lhs, rhs): + return (lhs, rhs) + elif issubclass(rhs, lhs): + return (rhs, lhs) + else: + return (lhs, rhs) + + +def create_intersection_backlinks( + lhs_backlinks: type[AbstractGelObjectBacklinksModel], + rhs_backlinks: type[AbstractGelObjectBacklinksModel], + result: type[BaseGelModelIntersection[Any, Any]], + result_type_name: TypeNameExpr, +) -> type[AbstractGelObjectBacklinksModel]: + reflection = type( + "__gel_reflection__", + _order_base_types( + lhs_backlinks.__gel_reflection__, + rhs_backlinks.__gel_reflection__, + ), + { + "name": result_type_name, + "type_name": result_type_name, + "pointers": ( + lhs_backlinks.__gel_reflection__.pointers + | rhs_backlinks.__gel_reflection__.pointers + ), + }, + ) + + # Generate field descriptors for backlinks. + field_descriptors: dict[str, ModelFieldDescriptor] = combine_dicts( + { + p_name: field_descriptor(result, p_name, path_alias.__gel_origin__) + for p_name in lhs_backlinks.__gel_reflection__.pointers + if ( + hasattr(lhs_backlinks, p_name) + and (path_alias := getattr(lhs_backlinks, p_name, None)) + is not None + and isinstance(path_alias, _qb.PathAlias) + ) + }, + { + p_name: field_descriptor(result, p_name, path_alias.__gel_origin__) + for p_name in rhs_backlinks.__gel_reflection__.pointers + if ( + hasattr(rhs_backlinks, p_name) + and (path_alias := getattr(rhs_backlinks, p_name, None)) + is not None + and isinstance(path_alias, _qb.PathAlias) + ) + }, + ) + + backlinks = type( + f"__{result_type_name.name}_backlinks__", + (BaseGelModelIntersectionBacklinks,), + { + 'lhs': lhs_backlinks, + 'rhs': rhs_backlinks, + '__gel_reflection__': reflection, + '__module__': __name__, + **field_descriptors, + }, + ) + + return backlinks diff --git a/gel/_internal/_qbmodel/_pydantic/__init__.py b/gel/_internal/_qbmodel/_pydantic/__init__.py index f47abec8..6a8b4401 100644 --- a/gel/_internal/_qbmodel/_pydantic/__init__.py +++ b/gel/_internal/_qbmodel/_pydantic/__init__.py @@ -32,6 +32,8 @@ GelLinkModel, GelModel, GelModelMeta, + GelObjectBacklinksModel, + GelObjectModel, LinkClassNamespace, ProxyModel, ) @@ -56,6 +58,8 @@ "GelLinkModel", "GelModel", "GelModelMeta", + "GelObjectBacklinksModel", + "GelObjectModel", "IdProperty", "LinkClassNamespace", "MultiProperty", diff --git a/gel/_internal/_qbmodel/_pydantic/_models.py b/gel/_internal/_qbmodel/_pydantic/_models.py index fa0019f7..c7b8f805 100644 --- a/gel/_internal/_qbmodel/_pydantic/_models.py +++ b/gel/_internal/_qbmodel/_pydantic/_models.py @@ -184,6 +184,20 @@ def __new__( # noqa: PYI034 else: cls.__gel_id_shape__ = None + # Add any proxied dunders from base classes + proxied_dunders: frozenset[str] | None = None + if cls_proxied_dunders := getattr(cls, "__gel_proxied_dunders__", ()): + proxied_dunders = cls_proxied_dunders + for base in bases: + if base_proxied_dunders := getattr( + base, "__gel_proxied_dunders__", () + ): + if proxied_dunders is None: + proxied_dunders = frozenset() + proxied_dunders |= base_proxied_dunders + if proxied_dunders: + cls.__gel_proxied_dunders__ = proxied_dunders + return cls def __setattr__(cls, name: str, value: Any, /) -> None: # noqa: N805 @@ -1323,6 +1337,39 @@ def model_copy( return copied +class GelObjectModel( + GelModel, + _abstract.AbstractGelObjectModel, + __gel_root_class__=True, +): + # Base class for object classes. + __gel_proxied_dunders__: ClassVar[frozenset[str]] = frozenset( + { + "__backlinks__", + } + ) + + +class GelObjectBacklinksModel( + GelSourceModel, + _abstract.AbstractGelObjectBacklinksModel, + __gel_root_class__=True, +): + # Base class for __backlinks__ classes. + __slots__ = ("__gel_copied_by_ref__",) + + def __getstate__(self) -> dict[Any, Any]: + state = super().__getstate__() + state["__gel_copied_by_ref__"] = getattr( + self, "__gel_copied_by_ref__", False + ) + return state + + def __setstate__(self, state: dict[Any, Any]) -> None: + super().__setstate__(state) + self.__gel_copied_by_ref__ = state["__gel_copied_by_ref__"] + + class GelLinkModel( GelSourceModel, _abstract.AbstractGelLinkModel, diff --git a/gel/_internal/_schemapath.py b/gel/_internal/_schemapath.py index a2dc21ae..4653eca8 100644 --- a/gel/_internal/_schemapath.py +++ b/gel/_internal/_schemapath.py @@ -304,7 +304,7 @@ def as_quoted_schema_name(self) -> str: @property def name(self) -> str: - return f"({' & '.join(a.name for a in self.args)})" + return f"{'_AND_'.join(a.name for a in self.args)}" @dataclasses.dataclass(frozen=True, kw_only=True) @@ -319,7 +319,7 @@ def as_quoted_schema_name(self) -> str: @property def name(self) -> str: - return f"({' | '.join(a.name for a in self.args)})" + return f"{'_OR_'.join(a.name for a in self.args)}" TypeNameExpr = TypeAliasType( diff --git a/gel/_internal/_testbase/_models.py b/gel/_internal/_testbase/_models.py index 63bd8322..911bb0a5 100644 --- a/gel/_internal/_testbase/_models.py +++ b/gel/_internal/_testbase/_models.py @@ -39,6 +39,7 @@ from gel._internal import _dirhash from gel._internal import _import_extras from gel._internal._codegen._models import PydanticModelsGenerator +from gel._internal._qbmodel._pydantic._models import GelModel from ._base import ( AsyncQueryTestCase, @@ -76,7 +77,6 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterator, Mapping, Sequence import pydantic - from gel._internal._qbmodel._pydantic._models import GelModel _unset = object() @@ -533,6 +533,30 @@ def assertPydanticPickles( getattr(model2, "__gel_changed_fields__", ...), ) + def _assertListEqualUnordered( + self, + expected: list[Any], + actual: list[Any], + ) -> None: + """Test that two collections have equal contents, though the order + may be different. + + Compares using the id of elements. Comparing lists of lists will not + work. + """ + + def key(value: Any) -> Any: + if isinstance(value, GelModel) and ( + model_id := getattr(value, "id", None) + ): + return model_id + else: + return id(value) + + return self.assertEqual( + sorted(expected, key=key), sorted(actual, key=key) + ) + def _assertObjectsWithFields( self, models: Collection[GelModel], diff --git a/gel/models/pydantic.py b/gel/models/pydantic.py index 0b36e424..2ce9801a 100644 --- a/gel/models/pydantic.py +++ b/gel/models/pydantic.py @@ -62,6 +62,7 @@ DateTimeLike, DefaultValue, GelLinkModelDescriptor, + GelObjectBacklinksModelDescriptor, GelScalarType, GelType, GelTypeMeta, @@ -83,6 +84,8 @@ GelLinkModel, GelModel, GelModelMeta, + GelObjectModel, + GelObjectBacklinksModel, IdProperty, LinkClassNamespace, ComputedLinkWithProps, @@ -152,6 +155,9 @@ "GelLinkModelDescriptor", "GelModel", "GelModelMeta", + "GelObjectBacklinksModel", + "GelObjectBacklinksModelDescriptor", + "GelObjectModel", "GelObjectTypeMetadata", "GelPointerReflection", "GelScalarType", diff --git a/tests/dbsetup/orm_qb.edgeql b/tests/dbsetup/orm_qb.edgeql index cb67dd08..c3e29d2e 100644 --- a/tests/dbsetup/orm_qb.edgeql +++ b/tests/dbsetup/orm_qb.edgeql @@ -277,23 +277,60 @@ insert Inh_AXA { axa := 10002, }; +insert Link_Inh_A { n := -1 }; +insert Link_Inh_A2 { n := -2 }; +insert Link_Inh_A3 { n := -3 }; + insert Link_Inh_A { n := 1, l := assert_exists((select Inh_A filter .a = 1 limit 1)), }; -insert Link_Inh_A { +insert Link_Inh_A2 { n := 4, l := assert_exists((select Inh_AB filter .a = 4 limit 1)), }; -insert Link_Inh_A { +insert Link_Inh_A2 { n := 7, l := assert_exists((select Inh_AC filter .a = 7 limit 1)), }; -insert Link_Inh_A { +insert Link_Inh_A3 { n := 13, l := assert_exists((select Inh_ABC filter .a = 13 limit 1)), }; -insert Link_Inh_A { +insert Link_Inh_A3 { n := 17, l := assert_exists((select Inh_AB_AC filter .a = 17 limit 1)), }; + +insert Link_Inh_AB { n := -1 }; +insert Link_Inh_AB { + n := 1004, + l := assert_exists((select Inh_AB filter .a = 4 limit 1)), +}; +insert Link_Inh_AB { + n := 1017, + l := assert_exists((select Inh_AB_AC filter .a = 17 limit 1)), +}; + +insert Link_Link_Inh_A { n := -1 }; + +insert Link_Link_Inh_A { + n := 1, + l := assert_exists((select Link_Inh_A filter .n = 1 limit 1)), +}; +insert Link_Link_Inh_A { + n := 4, + l := assert_exists((select Link_Inh_A filter .n = 4 limit 1)), +}; +insert Link_Link_Inh_A { + n := 7, + l := assert_exists((select Link_Inh_A filter .n = 7 limit 1)), +}; +insert Link_Link_Inh_A { + n := 13, + l := assert_exists((select Link_Inh_A filter .n = 13 limit 1)), +}; +insert Link_Link_Inh_A { + n := 17, + l := assert_exists((select Link_Inh_A filter .n = 17 limit 1)), +}; diff --git a/tests/dbsetup/orm_qb.gel b/tests/dbsetup/orm_qb.gel index ecef9227..9e30af07 100644 --- a/tests/dbsetup/orm_qb.gel +++ b/tests/dbsetup/orm_qb.gel @@ -617,6 +617,21 @@ type Link_Inh_A { on target delete allow; }; }; +type Link_Inh_A2 extending Link_Inh_A; +type Link_Inh_A3 extending Link_Inh_A2; +type Link_Inh_AB { + n: int64; + l: Inh_AB { + on target delete allow; + }; +}; + +type Link_Link_Inh_A { + n: int64; + l: Link_Inh_A { + on target delete allow; + }; +} function Read_Inh_A(x: Inh_A) -> int64 using (x.a ?? -1); function Read_Inh_A_Overload(x: Inh_A) -> int64 using (x.a ?? -1); diff --git a/tests/test_qb.py b/tests/test_qb.py index fdfb9533..9cd5a0aa 100644 --- a/tests/test_qb.py +++ b/tests/test_qb.py @@ -1800,33 +1800,54 @@ def test_qb_is_type_basic_09(self): ( default.Link_Inh_A, { - "n": 1, + "n": -1, + "l": None, + }, + ), + ( + default.Link_Inh_A2, + { + "n": -2, + "l": None, + }, + ), + ( + default.Link_Inh_A3, + { + "n": -3, "l": None, }, ), ( default.Link_Inh_A, + { + "n": 1, + "l": None, + }, + ), + ( + default.Link_Inh_A2, { "n": 4, "l": possible_targets[4], }, ), ( - default.Link_Inh_A, + default.Link_Inh_A2, { "n": 7, "l": None, }, ), ( - default.Link_Inh_A, + default.Link_Inh_A3, { "n": 13, "l": possible_targets[13], }, ), ( - default.Link_Inh_A, + default.Link_Inh_A3, { "n": 17, "l": possible_targets[17], @@ -1863,33 +1884,54 @@ def test_qb_is_type_basic_10(self): ( default.Link_Inh_A, { - "n": 1, + "n": -1, + "l": None, + }, + ), + ( + default.Link_Inh_A2, + { + "n": -2, + "l": None, + }, + ), + ( + default.Link_Inh_A3, + { + "n": -3, "l": None, }, ), ( default.Link_Inh_A, + { + "n": 1, + "l": None, + }, + ), + ( + default.Link_Inh_A2, { "n": 4, "l": possible_targets[4], }, ), ( - default.Link_Inh_A, + default.Link_Inh_A2, { "n": 7, "l": None, }, ), ( - default.Link_Inh_A, + default.Link_Inh_A3, { "n": 13, "l": possible_targets[13], }, ), ( - default.Link_Inh_A, + default.Link_Inh_A3, { "n": 17, "l": possible_targets[17], @@ -2113,6 +2155,451 @@ def test_qb_is_type_as_function_arg_03(self): ) self.assertEqual(sorted(result), [6, 13, 20]) + def test_qb_backlinks_01(self): + from models.orm_qb import default + + link_inh_a_objs = { + obj.n: obj + for obj in self.client.query(default.Link_Inh_A.select(n=True)) + } + + query = default.Inh_A.__backlinks__.l.is_(default.Link_Inh_A) + result = self.client.query(query) + self._assertListEqualUnordered( + [ + link_inh_a_objs[1], + link_inh_a_objs[4], + link_inh_a_objs[7], + link_inh_a_objs[13], + link_inh_a_objs[17], + ], + result, + ) + + def test_qb_backlinks_02(self): + # Two unrelated links with the same name + from models.orm_qb import default + + link_inh_a_objs = { + obj.n: obj + for obj in self.client.query(default.Link_Inh_A.select(n=True)) + } + link_inh_ab_objs = { + obj.n: obj + for obj in self.client.query(default.Link_Inh_AB.select(n=True)) + } + + # Check the ids of the un-typed backlinks + query_base = default.Inh_AB.__backlinks__.l + result_base = self.client.query(query_base) + expected_base = [ + link_inh_a_objs[4], + link_inh_a_objs[17], + link_inh_ab_objs[1004], + link_inh_ab_objs[1017], + ] + self._assertListEqualUnordered(expected_base, result_base) + + # with [is Link_Inh_A] + query_a = default.Inh_AB.__backlinks__.l.is_(default.Link_Inh_A) + result_a = self.client.query(query_a) + self._assertListEqualUnordered( + [link_inh_a_objs[4], link_inh_a_objs[17]], result_a + ) + + # with [is Link_Inh_AB] + query_ab = default.Inh_AB.__backlinks__.l.is_(default.Link_Inh_AB) + result_ab = self.client.query(query_ab) + self._assertListEqualUnordered( + [link_inh_ab_objs[1004], link_inh_ab_objs[1017]], result_ab + ) + + def test_qb_backlinks_03(self): + # Filter -> Backlink + from models.orm_qb import default, std + + link_inh_a_objs = { + obj.n: obj + for obj in self.client.query(default.Link_Inh_A.select(n=True)) + } + + query = default.Inh_A.filter( + lambda x: std.in_(x.a, {1, 4}) + ).__backlinks__.l.is_(default.Link_Inh_A) + result = self.client.query(query) + self._assertListEqualUnordered( + [link_inh_a_objs[1], link_inh_a_objs[4]], result + ) + + def test_qb_backlinks_04(self): + # Intersection -> Backlink + from models.orm_qb import default + + link_inh_a_objs = { + obj.n: obj + for obj in self.client.query(default.Link_Inh_A.select(n=True)) + } + + query = default.Inh_B.is_(default.Inh_A).__backlinks__.l.is_( + default.Link_Inh_A + ) + result = self.client.query(query) + self._assertListEqualUnordered( + [link_inh_a_objs[4], link_inh_a_objs[13], link_inh_a_objs[17]], + result, + ) + + def test_qb_backlinks_05(self): + # Filter ->Intersection -> Backlink + from models.orm_qb import default, std + + link_inh_a_objs = { + obj.n: obj + for obj in self.client.query(default.Link_Inh_A.select(n=True)) + } + + query = ( + default.Inh_B.filter(lambda x: std.in_(x.b, {5, 14})) + .is_(default.Inh_A) + .__backlinks__.l.is_(default.Link_Inh_A) + ) + result = self.client.query(query) + self._assertListEqualUnordered( + [link_inh_a_objs[4], link_inh_a_objs[13]], result + ) + + def test_qb_backlinks_06(self): + # Intersection -> Filter -> Backlink + from models.orm_qb import default, std + + link_inh_a_objs = { + obj.n: obj + for obj in self.client.query(default.Link_Inh_A.select(n=True)) + } + + query = ( + default.Inh_B.is_(default.Inh_A) + .filter(lambda x: std.in_(x.a, {4, 13})) + .__backlinks__.l.is_(default.Link_Inh_A) + ) + result = self.client.query(query) + self._assertListEqualUnordered( + [link_inh_a_objs[4], link_inh_a_objs[13]], result + ) + + def test_qb_backlinks_07(self): + # Intersection -> Intersection -> Backlink + from models.orm_qb import default + + link_inh_a_objs = { + obj.n: obj + for obj in self.client.query(default.Link_Inh_A.select(n=True)) + } + + query = ( + default.Inh_B.is_(default.Inh_C) + .is_(default.Inh_A) + .__backlinks__.l.is_(default.Link_Inh_A) + ) + result = self.client.query(query) + self._assertListEqualUnordered( + [link_inh_a_objs[13], link_inh_a_objs[17]], result + ) + + def test_qb_backlinks_08(self): + # Link -> Backlink + from models.orm_qb import default + + link_inh_a_objs = { + obj.n: obj + for obj in self.client.query(default.Link_Inh_A.select(n=True)) + } + + query = default.Link_Inh_AB.l.__backlinks__.l.is_(default.Link_Inh_A) + result = self.client.query(query) + self._assertListEqualUnordered( + [link_inh_a_objs[4], link_inh_a_objs[17]], result + ) + + def test_qb_backlinks_09(self): + # Link -> Filter -> Backlink + from models.orm_qb import default + + link_inh_a_objs = { + obj.n: obj + for obj in self.client.query(default.Link_Inh_A.select(n=True)) + } + + query = default.Link_Inh_AB.l.filter(a=4).__backlinks__.l.is_( + default.Link_Inh_A + ) + result = self.client.query(query) + self._assertListEqualUnordered([link_inh_a_objs[4]], result) + + def test_qb_backlinks_10(self): + # Filter -> Link -> Backlink + from models.orm_qb import default + + link_inh_a_objs = { + obj.n: obj + for obj in self.client.query(default.Link_Inh_A.select(n=True)) + } + + query = default.Link_Inh_AB.filter(n=1004).l.__backlinks__.l.is_( + default.Link_Inh_A + ) + result = self.client.query(query) + self._assertListEqualUnordered([link_inh_a_objs[4]], result) + + def test_qb_backlinks_11(self): + # Link -> Intersection -> Backlink + from models.orm_qb import default + + link_inh_a_objs = { + obj.n: obj + for obj in self.client.query(default.Link_Inh_A.select(n=True)) + } + + query = default.Link_Inh_AB.l.is_(default.Inh_AC).__backlinks__.l.is_( + default.Link_Inh_A + ) + result = self.client.query(query) + self._assertListEqualUnordered([link_inh_a_objs[17]], result) + + def test_qb_backlinks_12(self): + # Intersection -> Link -> Backlink + from models.orm_qb import default + + link_inh_a_objs = { + obj.n: obj + for obj in self.client.query(default.Link_Inh_A.select(n=True)) + } + + query = default.Link_Inh_A.is_( + default.Link_Inh_A2 + ).l.__backlinks__.l.is_(default.Link_Inh_A) + result = self.client.query(query) + self._assertListEqualUnordered( + [ + link_inh_a_objs[4], + link_inh_a_objs[7], + link_inh_a_objs[13], + link_inh_a_objs[17], + ], + result, + ) + + def test_qb_backlinks_13(self): + # Backlink -> Backlink + from models.orm_qb import default + + link_link_inh_a_objs = { + obj.n: obj + for obj in self.client.query( + default.Link_Link_Inh_A.select(n=True) + ) + } + + query = default.Inh_AB.__backlinks__.l.is_( + default.Link_Inh_A + ).__backlinks__.l.is_(default.Link_Link_Inh_A) + result = self.client.query(query) + self._assertListEqualUnordered( + [link_link_inh_a_objs[4], link_link_inh_a_objs[17]], result + ) + + def test_qb_backlinks_14(self): + # Filter -> Backlink -> Backlink + from models.orm_qb import default, std + + link_link_inh_a_objs = { + obj.n: obj + for obj in self.client.query( + default.Link_Link_Inh_A.select(n=True) + ) + } + + query = ( + default.Inh_A.filter(lambda x: std.in_(x.a, {4, 13})) + .__backlinks__.l.is_(default.Link_Inh_A) + .__backlinks__.l.is_(default.Link_Link_Inh_A) + ) + result = self.client.query(query) + self._assertListEqualUnordered( + [link_link_inh_a_objs[4], link_link_inh_a_objs[13]], result + ) + + def test_qb_backlinks_15(self): + # Backlink -> Filter -> Backlink + from models.orm_qb import default, std + + link_link_inh_a_objs = { + obj.n: obj + for obj in self.client.query( + default.Link_Link_Inh_A.select(n=True) + ) + } + + query = ( + default.Inh_A.__backlinks__.l.is_(default.Link_Inh_A) + .filter(lambda x: std.in_(x.n, {4, 13})) + .__backlinks__.l.is_(default.Link_Link_Inh_A) + ) + result = self.client.query(query) + self._assertListEqualUnordered( + [link_link_inh_a_objs[4], link_link_inh_a_objs[13]], result + ) + + def test_qb_backlinks_16(self): + # Backlink -> Backlink -> Filter + from models.orm_qb import default, std + + link_link_inh_a_objs = { + obj.n: obj + for obj in self.client.query( + default.Link_Link_Inh_A.select(n=True) + ) + } + + query = ( + default.Inh_A.__backlinks__.l.is_(default.Link_Inh_A) + .__backlinks__.l.is_(default.Link_Link_Inh_A) + .filter(lambda x: std.in_(x.n, {4, 13})) + ) + result = self.client.query(query) + self._assertListEqualUnordered( + [link_link_inh_a_objs[4], link_link_inh_a_objs[13]], result + ) + + def test_qb_backlinks_17(self): + # Shape ( Backlink ) + from models.orm_qb import default + + query = default.Inh_AB.select( + a=lambda x: x.__backlinks__.l.is_(default.Link_Inh_AB).limit(1).n + ) + result = self.client.query(query) + + self._assertObjectsWithFields( + result, + "a", + [ + ( + default.Inh_AB, + { + "a": 1004, + }, + ), + ( + default.Inh_AB, + { + "a": 1017, + }, + ), + ], + ) + + def test_qb_backlinks_18(self): + # Shape ( Filter -> Backlink ) + from models.orm_qb import default + + query = default.Inh_AB.select( + a=lambda x: x.filter(a=4) + .__backlinks__.l.is_(default.Link_Inh_AB) + .limit(1) + .n + ) + result = self.client.query(query) + + self._assertObjectsWithFields( + result, + "a", + [ + ( + default.Inh_AB, + { + "a": None, + }, + ), + ( + default.Inh_AB, + { + "a": 1004, + }, + ), + ], + ) + + def test_qb_backlinks_19(self): + # Shape ( Backlink -> Filter ) + from models.orm_qb import default, std + + query = default.Inh_AB.select( + a=lambda x: x.__backlinks__.l.is_(default.Link_Inh_AB) + .filter(lambda x: std.in_(x.l, default.Inh_AC)) + .limit(1) + .n + ) + result = self.client.query(query) + + self._assertObjectsWithFields( + result, + "a", + [ + ( + default.Inh_AB, + { + "a": None, + }, + ), + ( + default.Inh_AB, + { + "a": 1017, + }, + ), + ], + ) + + def test_qb_backlinks_20(self): + # Shape ( Backlink -> Backlink ) + from models.orm_qb import default + + query = default.Inh_AB.select( + a=lambda x: x.__backlinks__.l.is_(default.Link_Inh_A) + .__backlinks__.l.is_(default.Link_Link_Inh_A) + .limit(1) + .n + ) + result = self.client.query(query) + + self._assertObjectsWithFields( + result, + "a", + [ + ( + default.Inh_AB, + { + "a": 4, + }, + ), + ( + default.Inh_AB, + { + "a": 17, + }, + ), + ], + ) + + @tb.skip_typecheck + def test_qb_backlinks_error_01(self): + from models.orm_qb import default + + query = default.Inh_A.__backlinks__ + with self.assertRaisesRegex(ValueError, "unsupported query type"): + self.client.query(query) + class TestQueryBuilderModify(tb.ModelTestCase): """This test suite is for data manipulation using QB."""