@@ -3721,16 +3721,67 @@ def _write_potentially_overlapping_overloads(
37213721 # SEE ABOVE: This is what we actually want.
37223722 # key=lambda o: (generality_key(o), o.edgeql_signature), # noqa: ERA001, E501
37233723 )
3724+ base_generic_overload : dict [_Callable_T , _Callable_T ] = {}
37243725
37253726 for overload in overloads :
37263727 overload_signatures [overload ] = {}
3728+
3729+ if overload .schemapath == SchemaPath ('std' , 'IF' ):
3730+ # HACK: Pretend the base overload of std::IF is generic on
3731+ # anyobject.
3732+ #
3733+ # The base overload of std::IF is
3734+ # (anytype, std::bool, anytype) -> anytype
3735+ #
3736+ # However, this causes an overlap with overloading for bool
3737+ # arguments since
3738+ # (anytype, builtin.bool, anytype) -> anytype
3739+ # overlaps with
3740+ # (std::bool, builtin.bool, std::bool) -> std::bool
3741+ #
3742+ # We resolve this by generating the specializations for anytype
3743+ # but using anyobject as the base generic type.
3744+
3745+ def anytype_to_anyobject (
3746+ refl_type : reflection .Type ,
3747+ default : reflection .Type | reflection .TypeRef ,
3748+ ) -> reflection .Type | reflection .TypeRef :
3749+ if isinstance (refl_type , reflection .PseudoType ):
3750+ return self ._types_by_name ["anyobject" ]
3751+ return default
3752+
3753+ base_generic_overload [overload ] = dataclasses .replace (
3754+ overload ,
3755+ params = [
3756+ dataclasses .replace (
3757+ param ,
3758+ type = anytype_to_anyobject (
3759+ param .get_type (self ._types ), param .type
3760+ ),
3761+ )
3762+ for param in overload .params
3763+ ],
3764+ return_type = anytype_to_anyobject (
3765+ overload .get_return_type (self ._types ),
3766+ overload .return_type ,
3767+ ),
3768+ )
3769+
37273770 for param in param_getter (overload ):
37283771 param_overload_map [param .key ].add (overload )
37293772 param_type = param .get_type (self ._types )
37303773 # Unwrap the variadic type (it is reflected as an array of T)
37313774 if param .kind is reflection .CallableParamKind .Variadic :
37323775 if reflection .is_array_type (param_type ):
37333776 param_type = param_type .get_element_type (self ._types )
3777+
3778+ if (
3779+ overload .schemapath == SchemaPath ('std' , 'IF' )
3780+ and param_type .is_pseudo
3781+ ):
3782+ # Also generate the base signature using anyobject
3783+ param_type = self ._types_by_name ["anyobject" ]
3784+
37343785 # Start with the base parameter type
37353786 overload_signatures [overload ][param .key ] = [param_type ]
37363787
@@ -3842,7 +3893,10 @@ def specialization_sort_key(t: reflection.Type) -> int:
38423893 for overload in overloads :
38433894 if overload_specs := overloads_specializations .get (overload ):
38443895 expanded_overloads .extend (overload_specs )
3845- expanded_overloads .append (overload )
3896+ if overload in base_generic_overload :
3897+ expanded_overloads .append (base_generic_overload [overload ])
3898+ else :
3899+ expanded_overloads .append (overload )
38463900 overloads = expanded_overloads
38473901
38483902 overload_order = {overload : i for i , overload in enumerate (overloads )}
0 commit comments