From b4e1455320148ddb3ed638edbe54043e1ee7030c Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Thu, 18 Sep 2025 10:27:42 +0200 Subject: [PATCH 1/4] apply_type_with_promotion --- src/Tracing.jl | 83 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 82 insertions(+), 1 deletion(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index 015a2ccff7..a47ddd4a87 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -736,7 +736,7 @@ Base.@nospecializeinfer function traced_type_inner( end if !isempty(subParms) - TT2 = Core.apply_type(T.name.wrapper, subParms...) + TT2 = apply_type_with_promotion(T.name.wrapper, subParms) else TT2 = T end @@ -745,6 +745,14 @@ Base.@nospecializeinfer function traced_type_inner( if fieldcount(T) == fieldcount(TT2) legal = true for f in 1:fieldcount(T) + if isa(Base.unwrap_unionall(T.name.wrapper).types[f], TypeVar) + # The field is constrained by a TypeVar directly, + # so we don't need to check. + # (The check below would fail if the typevar was promoted as + # we don't get the same result when calling traced_type_inner + # on the field type directly.) + continue + end subT = fieldtype(T, f) subT2 = fieldtype(TT2, f) subTT = traced_type_inner(subT, seen3, mode, track_numbers, sharding, runtime) @@ -857,6 +865,79 @@ const traced_type_cache = Dict{Tuple{TraceMode,Type,Any},Dict{Type,Type}}() # $(Expr(:meta, :generated, traced_type_generator)) # end +""" +This function tries to apply the param types to the wrapper type. +When there's a constraint conflict, it tries to resolve it by promoting the conflicting types. The new param type is then propagated in any param type that depends on it. +""" +function apply_type_with_promotion(wrapper, params, relevant_typevars=typevar_dict(wrapper)) + unwrapped = Base.unwrap_unionall(wrapper) # remove all the typevars + + new_params = copy(params) + + changed = true + iter = 0 + while changed && iter < 100 + changed = false + for (i, param) in enumerate(params) + # Add back the typevars to only one of the parameters: + rewrapped = Base.rewrap_unionall(unwrapped.parameters[i], wrapper) + + sz = @ccall jl_subtype_env_size(rewrapped::Any)::Cint + arr = Array{Any}(undef, sz) + + # Verify that the currently selected parameter subtypes the param in the wrapper type. + # In the process, `arr` is filled with with the required types for each parameter used by the current parameter: + is_subtype = + (@ccall jl_subtype_env( + params[i]::Any, rewrapped::Any, arr::Ptr{Any}, sz::Cint + )::Cint) == 1 + !is_subtype && error( + "Failed to find a valid type for typevar $i ($(params[i]) <: $(rewrapped) == false)", + ) + + # Check whether the required types are supertypes of all the parameter types we currently have: + current_unionall = rewrapped + for value in arr + # Peel open the unionall to figure out which typevar each `value` corresponds to: + typevar = current_unionall.var + current_unionall = current_unionall.body + + # `param` might have other typevars that don't occur in `wrapper`, + # here we first check if the typevar is actually relevant: + if haskey(relevant_typevars, typevar) + param_i = relevant_typevars[typevar] + value <: params[param_i] && continue + + # Found a conflict! Figure out a new param type by promoting: + promoted = promote_type(value, params[param_i]) + new_params[param_i] = promoted + + if value != promoted + # This happens when `value` lost the promotion battle. + # At this point, we need to update the problematic parameter in`value`. + d = typevar_dict(rewrapped) + v = [param.parameters...] + v[d[typevar]] = promoted + new_params[i] = params[i] = apply_type_with_promotion(rewrapped, v) + end + changed = true + end + end + end + params .= new_params + iter += 1 + end + return Core.apply_type(wrapper, params...) +end + +function typevar_dict(t) + d = Dict() + for (i, name) in enumerate(Base.unwrap_unionall(t).parameters) + d[name] = i + end + return d +end + Base.@assume_effects :total @inline function traced_type( T::Type, ::Val{mode}, track_numbers::Type, sharding, runtime ) where {mode} From 12b1d6390d3a45e5eed38a6d371f1085c0e97bec Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sat, 20 Sep 2025 00:03:57 +0200 Subject: [PATCH 2/4] add test --- test/tracing.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/tracing.jl b/test/tracing.jl index 21a4f6559b..8b40e3d358 100644 --- a/test/tracing.jl +++ b/test/tracing.jl @@ -252,6 +252,20 @@ end Reactant.XLA.runtime(), ) end + @testset "apply_type_with_promotion" begin + struct Bar{T} + b::T + end + struct Foo{T, B<:Bar{T}, AT<:AbstractArray{T}} + a::AT + b::B + end + @test Reactant.apply_type_with_promotion(Foo, [Float64, Bar{Float64}, Reactant.TracedRArray{Float64, 1}]) == Foo{ + TracedRNumber{Float64}, + Bar{TracedRNumber{Float64}}, + Reactant.TracedRArray{Float64, 1}, + } + end end @testset "specialized dispatches" begin From 138b9dba7f8cab62c6c120eb25cf171ec380e0d9 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sat, 20 Sep 2025 00:09:42 +0200 Subject: [PATCH 3/4] test formatting --- test/tracing.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/tracing.jl b/test/tracing.jl index 8b40e3d358..3711f3279d 100644 --- a/test/tracing.jl +++ b/test/tracing.jl @@ -256,14 +256,16 @@ end struct Bar{T} b::T end - struct Foo{T, B<:Bar{T}, AT<:AbstractArray{T}} + struct Foo{T,B<:Bar{T},AT<:AbstractArray{T}} a::AT b::B end - @test Reactant.apply_type_with_promotion(Foo, [Float64, Bar{Float64}, Reactant.TracedRArray{Float64, 1}]) == Foo{ + @test Reactant.apply_type_with_promotion( + Foo, [Float64, Bar{Float64}, Reactant.TracedRArray{Float64,1}] + ) == Foo{ TracedRNumber{Float64}, Bar{TracedRNumber{Float64}}, - Reactant.TracedRArray{Float64, 1}, + Reactant.TracedRArray{Float64,1}, } end end From c9b4f129a41c79e7e5380499402e42853cb8882d Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sat, 20 Sep 2025 00:10:15 +0200 Subject: [PATCH 4/4] remove `new_params` --- src/Tracing.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index a47ddd4a87..01994452c7 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -871,8 +871,7 @@ When there's a constraint conflict, it tries to resolve it by promoting the conf """ function apply_type_with_promotion(wrapper, params, relevant_typevars=typevar_dict(wrapper)) unwrapped = Base.unwrap_unionall(wrapper) # remove all the typevars - - new_params = copy(params) + params = [params...] changed = true iter = 0 @@ -910,7 +909,7 @@ function apply_type_with_promotion(wrapper, params, relevant_typevars=typevar_di # Found a conflict! Figure out a new param type by promoting: promoted = promote_type(value, params[param_i]) - new_params[param_i] = promoted + params[param_i] = promoted if value != promoted # This happens when `value` lost the promotion battle. @@ -918,13 +917,12 @@ function apply_type_with_promotion(wrapper, params, relevant_typevars=typevar_di d = typevar_dict(rewrapped) v = [param.parameters...] v[d[typevar]] = promoted - new_params[i] = params[i] = apply_type_with_promotion(rewrapped, v) + params[i] = apply_type_with_promotion(rewrapped, v) end changed = true end end end - params .= new_params iter += 1 end return Core.apply_type(wrapper, params...)