diff --git a/src/Tracing.jl b/src/Tracing.jl index 015a2ccff7..01994452c7 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -736,7 +736,7 @@ Base.@nospecializeinfer function traced_type_inner( end if !isempty(subParms) - TT2 = Core.apply_type(T.name.wrapper, subParms...) + TT2 = apply_type_with_promotion(T.name.wrapper, subParms) else TT2 = T end @@ -745,6 +745,14 @@ Base.@nospecializeinfer function traced_type_inner( if fieldcount(T) == fieldcount(TT2) legal = true for f in 1:fieldcount(T) + if isa(Base.unwrap_unionall(T.name.wrapper).types[f], TypeVar) + # The field is constrained by a TypeVar directly, + # so we don't need to check. + # (The check below would fail if the typevar was promoted as + # we don't get the same result when calling traced_type_inner + # on the field type directly.) + continue + end subT = fieldtype(T, f) subT2 = fieldtype(TT2, f) subTT = traced_type_inner(subT, seen3, mode, track_numbers, sharding, runtime) @@ -857,6 +865,77 @@ const traced_type_cache = Dict{Tuple{TraceMode,Type,Any},Dict{Type,Type}}() # $(Expr(:meta, :generated, traced_type_generator)) # end +""" +This function tries to apply the param types to the wrapper type. +When there's a constraint conflict, it tries to resolve it by promoting the conflicting types. The new param type is then propagated in any param type that depends on it. +""" +function apply_type_with_promotion(wrapper, params, relevant_typevars=typevar_dict(wrapper)) + unwrapped = Base.unwrap_unionall(wrapper) # remove all the typevars + params = [params...] + + changed = true + iter = 0 + while changed && iter < 100 + changed = false + for (i, param) in enumerate(params) + # Add back the typevars to only one of the parameters: + rewrapped = Base.rewrap_unionall(unwrapped.parameters[i], wrapper) + + sz = @ccall jl_subtype_env_size(rewrapped::Any)::Cint + arr = Array{Any}(undef, sz) + + # Verify that the currently selected parameter subtypes the param in the wrapper type. + # In the process, `arr` is filled with with the required types for each parameter used by the current parameter: + is_subtype = + (@ccall jl_subtype_env( + params[i]::Any, rewrapped::Any, arr::Ptr{Any}, sz::Cint + )::Cint) == 1 + !is_subtype && error( + "Failed to find a valid type for typevar $i ($(params[i]) <: $(rewrapped) == false)", + ) + + # Check whether the required types are supertypes of all the parameter types we currently have: + current_unionall = rewrapped + for value in arr + # Peel open the unionall to figure out which typevar each `value` corresponds to: + typevar = current_unionall.var + current_unionall = current_unionall.body + + # `param` might have other typevars that don't occur in `wrapper`, + # here we first check if the typevar is actually relevant: + if haskey(relevant_typevars, typevar) + param_i = relevant_typevars[typevar] + value <: params[param_i] && continue + + # Found a conflict! Figure out a new param type by promoting: + promoted = promote_type(value, params[param_i]) + params[param_i] = promoted + + if value != promoted + # This happens when `value` lost the promotion battle. + # At this point, we need to update the problematic parameter in`value`. + d = typevar_dict(rewrapped) + v = [param.parameters...] + v[d[typevar]] = promoted + params[i] = apply_type_with_promotion(rewrapped, v) + end + changed = true + end + end + end + iter += 1 + end + return Core.apply_type(wrapper, params...) +end + +function typevar_dict(t) + d = Dict() + for (i, name) in enumerate(Base.unwrap_unionall(t).parameters) + d[name] = i + end + return d +end + Base.@assume_effects :total @inline function traced_type( T::Type, ::Val{mode}, track_numbers::Type, sharding, runtime ) where {mode} diff --git a/test/tracing.jl b/test/tracing.jl index 21a4f6559b..3711f3279d 100644 --- a/test/tracing.jl +++ b/test/tracing.jl @@ -252,6 +252,22 @@ end Reactant.XLA.runtime(), ) end + @testset "apply_type_with_promotion" begin + struct Bar{T} + b::T + end + struct Foo{T,B<:Bar{T},AT<:AbstractArray{T}} + a::AT + b::B + end + @test Reactant.apply_type_with_promotion( + Foo, [Float64, Bar{Float64}, Reactant.TracedRArray{Float64,1}] + ) == Foo{ + TracedRNumber{Float64}, + Bar{TracedRNumber{Float64}}, + Reactant.TracedRArray{Float64,1}, + } + end end @testset "specialized dispatches" begin