Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 253 additions & 4 deletions gel/_internal/_codegen/_models/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,12 @@ class ModuleAspect(enum.Enum):
LATE = enum.auto()


@dataclasses.dataclass(frozen=True, kw_only=True)
class Backlink:
source: reflection.ObjectType
pointer: reflection.Pointer


class SchemaGenerator:
def __init__(
self,
Expand All @@ -451,6 +457,9 @@ def __init__(
self._std_modules: list[SchemaPath] = []
self._types: Mapping[str, reflection.Type] = {}
self._casts: reflection.CastMatrix
self._backlinks: dict[
reflection.ObjectType, dict[str, list[Backlink]]
] = {}
self._operators: reflection.OperatorMatrix
self._functions: list[reflection.Function]
self._globals: list[reflection.Global]
Expand Down Expand Up @@ -579,6 +588,7 @@ def run(self, outdir: pathlib.Path) -> tuple[Schema, set[pathlib.Path]]:
all_casts=self._casts,
all_operators=self._operators,
all_globals=self._globals,
all_backlinks=self._backlinks,
modules=self._modules,
schema_part=self._schema_part,
)
Expand All @@ -604,6 +614,7 @@ def run(self, outdir: pathlib.Path) -> tuple[Schema, set[pathlib.Path]]:
all_casts=self._casts,
all_operators=self._operators,
all_globals=self._globals,
all_backlinks=self._backlinks,
modules=all_modules,
schema_part=self._schema_part,
)
Expand Down Expand Up @@ -659,6 +670,24 @@ def introspect_schema(self) -> Schema:
if reflection.is_object_type(t):
name = t.schemapath
self._modules[name.parent]["object_types"][name.name] = t

for p in t.pointers:
if (
reflection.is_link(p)
# For now don't include std::BaseObject.__type__
# Users should just select on the appropriate type
and p.name != "__type__"
):
target = self._types[p.target_id]
assert isinstance(target, reflection.ObjectType)
if target not in self._backlinks:
self._backlinks[target] = {}
if p.name not in self._backlinks[target]:
self._backlinks[target][p.name] = []
self._backlinks[target][p.name].append(
Backlink(source=t, pointer=p)
)

elif reflection.is_scalar_type(t):
name = t.schemapath
self._modules[name.parent]["scalar_types"][name.name] = t
Expand Down Expand Up @@ -694,6 +723,7 @@ def _generate_common_types(
all_casts=self._casts,
all_operators=self._operators,
all_globals=self._globals,
all_backlinks=self._backlinks,
modules=self._modules,
schema_part=self._schema_part,
)
Expand Down Expand Up @@ -944,6 +974,9 @@ def __init__(
all_casts: reflection.CastMatrix,
all_operators: reflection.OperatorMatrix,
all_globals: list[reflection.Global],
all_backlinks: Mapping[
reflection.ObjectType, Mapping[str, Sequence[Backlink]]
],
modules: Collection[SchemaPath],
schema_part: reflection.SchemaPart,
) -> None:
Expand All @@ -954,6 +987,7 @@ def __init__(
self._casts = all_casts
self._operators = all_operators
self._globals = all_globals
self._backlinks = all_backlinks
schema_obj_type = None
for t in all_types.values():
self._types_by_name[t.name] = t
Expand Down Expand Up @@ -2340,7 +2374,7 @@ def write_generic_types(
unpack = self.import_name("typing_extensions", "Unpack")
geltype = self.import_name(BASE_IMPL, "GelType")
geltypemeta = self.import_name(BASE_IMPL, "GelTypeMeta")
gelmodel = self.import_name(BASE_IMPL, "GelModel")
gelobjectmodel = self.import_name(BASE_IMPL, "GelObjectModel")
gelmodelmeta = self.import_name(BASE_IMPL, "GelModelMeta")
anytuple = self.import_name(BASE_IMPL, "AnyTuple")
anynamedtuple = self.import_name(BASE_IMPL, "AnyNamedTuple")
Expand All @@ -2367,7 +2401,7 @@ def write_generic_types(
geltype,
],
SchemaPath("std", "anyobject"): [
gelmodel,
gelobjectmodel,
"anytype",
],
SchemaPath("std", "anytuple"): [
Expand Down Expand Up @@ -4086,6 +4120,9 @@ def _mangle_default_shape(name: str) -> str:
if proplinks:
self.write_object_type_link_models(objtype)

if objtype.name != "std::FreeObject":
self._write_object_backlinks(objtype)

anyobject_meta = self.get_object(
SchemaPath("std", "__anyobject_meta__"),
aspect=ModuleAspect.SHAPES,
Expand Down Expand Up @@ -4213,6 +4250,16 @@ def write_id_computed(
self.write(f"__gel_type_class__ = __{name}_ops__")
if objtype.name == "std::BaseObject":
write_id_attr(objtype, "RequiredId")

if objtype.name != "std::FreeObject":
backlinks_model_name = self._mangle_backlinks_model_name(name)
g_oblm_desc = self.import_name(
BASE_IMPL, "GelObjectBacklinksModelDescriptor"
)
oblm_desc = f"{g_oblm_desc}[{backlinks_model_name}]"
self.write(f"__backlinks__: {oblm_desc} = {oblm_desc}()")
self.write()

self._write_base_object_type_body(objtype, include_tname=True)
with self.type_checking():
self._write_object_type_qb_methods(objtype)
Expand Down Expand Up @@ -4333,6 +4380,208 @@ def write_id_computed(

self.write()

@staticmethod
def _mangle_backlinks_model_name(name: str) -> str:
return f"__{name}_backlinks__"

def _write_object_backlinks(
self,
objtype: reflection.ObjectType,
) -> None:
type_name = objtype.schemapath
name = type_name.name

schema_path = self.import_name(BASE_IMPL, "SchemaPath")
parametric_type_name = self.import_name(
BASE_IMPL, "ParametricTypeName"
)
computed_multi_link = self.import_name(BASE_IMPL, "ComputedMultiLink")
std_base_object_t = self.get_object(
SchemaPath.from_segments("std", "BaseObject"),
aspect=ModuleAspect.SHAPES,
)

objtype_bases = [
base_type
for base_ref in objtype.bases
if (base_type := self._types.get(base_ref.id, None))
if isinstance(base_type, reflection.ObjectType)
]
objtype_name = type_name.as_python_code(
schema_path, parametric_type_name
)

backlinks_model_name = self._mangle_backlinks_model_name(name)

backlinks_class_bases: list[str]
backlinks_reflection_class_bases: list[str]

# Backlinks' reflections' pointers combine the current types
# backlinks with those of their base types.
#
# Eg. With `type A` and `type B extending A`,
# - __A_backlinks__.__gel_reflection__.pointers will contain
# all backlinks to A (and BaseObject)
# - __B_backlinks__.__gel_reflection__.pointers will contain
# all backlinks to B, backlinks to A.
#
# No base type backlinks are needed for BaseObject.
backlinks_reflection_pointer_bases: list[str]

if objtype.name == "std::BaseObject":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment on why this is special?

# __BaseObject_backlinks__ derives from GelObjectBacklinksModel
# while all other backlinks derive from it directly on indirectly.
object_backlinks_model = self.import_name(
BASE_IMPL, "GelObjectBacklinksModel"
)
backlinks_class_bases = [object_backlinks_model]
backlinks_reflection_class_bases = [
f"{object_backlinks_model}.__gel_reflection__"
]
backlinks_reflection_pointer_bases = []
else:
backlinks_class_bases = [
self.get_object(
SchemaPath(
base_type.schemapath.parent,
self._mangle_backlinks_model_name(
base_type.schemapath.name
),
),
aspect=ModuleAspect.SHAPES,
)
for base_type in objtype_bases
]
backlinks_reflection_class_bases = [
f"{bbt}.__gel_reflection__" for bbt in backlinks_class_bases
]
backlinks_reflection_pointer_bases = backlinks_class_bases

object_backlinks = self._backlinks.get(objtype, {})

# The backlinks model class
with self._class_def(backlinks_model_name, backlinks_class_bases):
with self._class_def(
"__gel_reflection__", backlinks_reflection_class_bases
):
self.write(f"name = {objtype_name}")
self.write(f"type_name = {objtype_name}")
self._write_backlinks_pointers_reflection(
object_backlinks, backlinks_reflection_pointer_bases
)

for backlink_name in object_backlinks:
backlink_t = f"{computed_multi_link}[{std_base_object_t}]"
self.write(f"{backlink_name}: {backlink_t}")

self.export(backlinks_model_name)

self.write()

def _write_backlinks_pointers_reflection(
self,
object_backlinks: Mapping[str, Sequence[Backlink]],
backlinks_reflection_pointer_bases: Sequence[str],
) -> None:
dict_ = self.import_name(
"builtins", "dict", import_time=ImportTime.typecheck
)
str_ = self.import_name(
"builtins", "str", import_time=ImportTime.typecheck
)
gel_ptr_ref = self.import_name(
BASE_IMPL,
"GelPointerReflection",
import_time=ImportTime.runtime
if object_backlinks
else ImportTime.typecheck,
)
lazyclassproperty = self.import_name(BASE_IMPL, "LazyClassProperty")
ptr_ref_t = f"{dict_}[{str_}, {gel_ptr_ref}]"
with self._classmethod_def(
"pointers",
[],
ptr_ref_t,
decorators=(f'{lazyclassproperty}["{ptr_ref_t}"]',),
):
if object_backlinks:
self.write(f"my_ptrs: {ptr_ref_t} = {{")
classes = {
"SchemaPath": self.import_name(BASE_IMPL, "SchemaPath"),
"ParametricTypeName": self.import_name(
BASE_IMPL, "ParametricTypeName"
),
"GelPointerReflection": gel_ptr_ref,
"Cardinality": self.import_name(BASE_IMPL, "Cardinality"),
"PointerKind": self.import_name(BASE_IMPL, "PointerKind"),
"StdBaseObject": self.get_object(
SchemaPath.from_segments("std", "BaseObject"),
aspect=ModuleAspect.SHAPES,
),
}
with self.indented():
for (
backlink_name,
backlink_values,
) in object_backlinks.items():
r = self._reflect_backlink(
backlink_name, backlink_values, classes
)
self.write(f"{backlink_name!r}: {r},")
self.write("}")
else:
self.write(f"my_ptrs: {ptr_ref_t} = {{}}")

if backlinks_reflection_pointer_bases:
pp = "__gel_reflection__.pointers"
ret = self.format_list(
"return ({list})",
[
"my_ptrs",
*_map_name(
lambda s: f"{s}.{pp}",
backlinks_reflection_pointer_bases,
),
],
separator=" | ",
carry_separator=True,
)
else:
ret = "return my_ptrs"

self.write(ret)

self.write()

def _reflect_backlink(
self,
name: str,
backlinks: Sequence[Backlink],
classes: dict[str, str],
) -> str:
kwargs: dict[str, str] = {
"name": repr(name),
"type": classes["StdBaseObject"],
"kind": (
f"{classes['PointerKind']}({str(reflection.PointerKind.Link)!r})"
),
"cardinality": (
f"{classes['Cardinality']}({str(reflection.Cardinality.Many)!r})"
),
"computed": "True",
"readonly": "True",
"has_default": "False",
"mutable": "False",
}

# For now don't get any back link props
kwargs["properties"] = "None"

return self.format_list(
f"{classes['GelPointerReflection']}({{list}})",
[f"{k}={v}" for k, v in kwargs.items()],
)

@contextlib.contextmanager
def _object_type_variant(
self,
Expand All @@ -4354,8 +4603,8 @@ def _object_type_variant(
)

if not list(variant_bases):
gel_model = self.import_name(BASE_IMPL, "GelModel")
bases.append(gel_model)
gel_object_model = self.import_name(BASE_IMPL, "GelObjectModel")
bases.append(gel_object_model)

with self._class_def(
variant,
Expand Down
1 change: 1 addition & 0 deletions gel/_internal/_qb/_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ class PathExpr(AtomicExpr):
name: str
is_lprop: bool = False
is_link: bool = False
is_backlink: bool = False

def subnodes(self) -> Iterable[Node]:
return (self.source,)
Expand Down
2 changes: 2 additions & 0 deletions gel/_internal/_qb/_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ def __edgeql_expr__(self, *, ctx: ScopeContext) -> str:
source = current.source
if isinstance(source, PathPrefix) and source.lprop_pivot:
step = f"@{_edgeql.quote_ident(current.name)}"
elif current.is_backlink:
step = f".<{_edgeql.quote_ident(current.name)}"
else:
step = f".{_edgeql.quote_ident(current.name)}"
steps.append(step)
Expand Down
Loading