Skip to content

Commit 71d8b53

Browse files
committed
Add hack for bool overload.
1 parent a917cb2 commit 71d8b53

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

gel/_internal/_codegen/_models/_pydantic.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3721,16 +3721,67 @@ def _write_potentially_overlapping_overloads(
37213721
# SEE ABOVE: This is what we actually want.
37223722
# key=lambda o: (generality_key(o), o.edgeql_signature), # noqa: ERA001, E501
37233723
)
3724+
base_generic_overload: dict[_Callable_T, _Callable_T] = {}
37243725

37253726
for overload in overloads:
37263727
overload_signatures[overload] = {}
3728+
3729+
if overload.schemapath == SchemaPath('std', 'IF'):
3730+
# HACK: Pretend the base overload of std::IF is generic on
3731+
# anyobject.
3732+
#
3733+
# The base overload of std::IF is
3734+
# (anytype, std::bool, anytype) -> anytype
3735+
#
3736+
# However, this causes an overlap with overloading for bool
3737+
# arguments since
3738+
# (anytype, builtin.bool, anytype) -> anytype
3739+
# overlaps with
3740+
# (std::bool, builtin.bool, std::bool) -> std::bool
3741+
#
3742+
# We resolve this by generating the specializations for anytype
3743+
# but using anyobject as the base generic type.
3744+
3745+
def anytype_to_anyobject(
3746+
refl_type: reflection.Type,
3747+
default: reflection.Type | reflection.TypeRef,
3748+
) -> reflection.Type | reflection.TypeRef:
3749+
if isinstance(refl_type, reflection.PseudoType):
3750+
return self._types_by_name["anyobject"]
3751+
return default
3752+
3753+
base_generic_overload[overload] = dataclasses.replace(
3754+
overload,
3755+
params=[
3756+
dataclasses.replace(
3757+
param,
3758+
type=anytype_to_anyobject(
3759+
param.get_type(self._types), param.type
3760+
),
3761+
)
3762+
for param in overload.params
3763+
],
3764+
return_type=anytype_to_anyobject(
3765+
overload.get_return_type(self._types),
3766+
overload.return_type,
3767+
),
3768+
)
3769+
37273770
for param in param_getter(overload):
37283771
param_overload_map[param.key].add(overload)
37293772
param_type = param.get_type(self._types)
37303773
# Unwrap the variadic type (it is reflected as an array of T)
37313774
if param.kind is reflection.CallableParamKind.Variadic:
37323775
if reflection.is_array_type(param_type):
37333776
param_type = param_type.get_element_type(self._types)
3777+
3778+
if (
3779+
overload.schemapath == SchemaPath('std', 'IF')
3780+
and param_type.is_pseudo
3781+
):
3782+
# Also generate the base signature using anyobject
3783+
param_type = self._types_by_name["anyobject"]
3784+
37343785
# Start with the base parameter type
37353786
overload_signatures[overload][param.key] = [param_type]
37363787

@@ -3842,7 +3893,10 @@ def specialization_sort_key(t: reflection.Type) -> int:
38423893
for overload in overloads:
38433894
if overload_specs := overloads_specializations.get(overload):
38443895
expanded_overloads.extend(overload_specs)
3845-
expanded_overloads.append(overload)
3896+
if overload in base_generic_overload:
3897+
expanded_overloads.append(base_generic_overload[overload])
3898+
else:
3899+
expanded_overloads.append(overload)
38463900
overloads = expanded_overloads
38473901

38483902
overload_order = {overload: i for i, overload in enumerate(overloads)}

gel/_internal/_reflection/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
ScalarType,
6969
TupleType,
7070
Type,
71+
TypeRef,
7172
compare_type_generality,
7273
fetch_types,
7374
is_abstract_type,
@@ -126,6 +127,7 @@
126127
"Type",
127128
"TypeKind",
128129
"TypeModifier",
130+
"TypeRef",
129131
"compare_callable_generality",
130132
"compare_type_generality",
131133
"fetch_branch_state",

0 commit comments

Comments
 (0)