-
Notifications
You must be signed in to change notification settings - Fork 52
Open
Description
I am having an issue tracing a function that branches during a reduction. I've attached an MWE including an error I get on Julia 1.10 and the most recent released Reactant version.
using Reactant, FillArrays
struct ImageDirichlet{T, A <: AbstractMatrix{T}, S}
α::A
α0::T
lmnB::S
function ImageDirichlet(α::AbstractMatrix{T}) where {T}
α0 = sum(α)
lmnB = sum(log, α) - log(α0)
return new{T, typeof(α), typeof(lmnB)}(α, α0, lmnB)
end
end
function ImageDirichlet(α::Real, nx::Int, ny::Int)
return ImageDirichlet(FillArrays.Fill(α, nx, ny))
end
Base.size(d::ImageDirichlet) = size(d.α)
function logpdf(d::ImageDirichlet, x::AbstractMatrix)
l = dirichlet_lpdf(d.α, d.lmnB, x)
return l
end
function dirichlet_lpdf(α, lmnB, x)
s = mapreduce(+, x, α) do xi, αi
myxlogy(αi - 1, xi)
end
return s - lmnB
end
function myxlogy(x::Number, y::Number)
result = x * log(y)
r = ifelse(iszero(x) && isnan(y), zero(result), result)
return r
end
d = ImageDirichlet(1.0, 4, 4)
x = rand(4,4)
x ./= sum(x)
xr = Reactant.to_rarray(x)
@jit logpdf(d, xr)
ERROR: setfield!: immutable struct of type Float64 cannot be changed
Stacktrace:
[1] traced_setfield!
@ ~/.julia/packages/Reactant/Gn2f9/src/Compiler.jl:93 [inlined]
[2] traced_setfield_buffer!(::Val{…}, cache_dict::IdDict{…}, val::Float64, concrete_res::Tuple{…}, obj::FillArrays.Fill{…}, field::Int64, path::Tuple{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/Gn2f9/src/Compiler.jl:167
[3] traced_setfield_buffer!(runtime::Val{…}, cache_dict::IdDict{…}, concrete_res::Tuple{…}, obj::FillArrays.Fill{…}, field::Int64, path::Tuple{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/Gn2f9/src/Compiler.jl:161
[4] macro expansion
@ ~/.julia/packages/Reactant/Gn2f9/src/Compiler.jl:3252 [inlined]
[5] (::Reactant.Compiler.Thunk{…})(::ImageDirichlet{…}, ::ConcretePJRTArray{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/Gn2f9/src/Compiler.jl:3927
[6] top-level scope
@ REPL[6]:1This may be related to using FillArrays. If I change the definition to
function ImageDirichlet(α::Real, nx::Int, ny::Int)
return ImageDirichlet(fill(α, nx, ny))
endI get the error
ERROR: TypeError: non-boolean (Reactant.TracedRNumber{Bool}) used in boolean context
Stacktrace:
[1] myxlogy
@ ./REPL[42]:3 [inlined]
[2] (::Nothing)(none::typeof(myxlogy), none::Reactant.TracedRNumber{Float64}, none::Reactant.TracedRNumber{Float64})
@ Reactant ./<missing>:0
[3] string
@ ./strings/io.jl:189 [inlined]
[4] macro expansion
@ ~/.julia/packages/Reactant/Gn2f9/src/Ops.jl:50 [inlined]
[5] log
@ ~/.julia/packages/Reactant/Gn2f9/src/TracedRNumber.jl:527 [inlined]
[6] myxlogy
@ ./REPL[42]:2 [inlined]
[7] call_with_reactant(::Reactant.MustThrowError, ::typeof(myxlogy), ::Reactant.TracedRNumber{…}, ::Reactant.TracedRNumber{…})
@ Reactant ~/.julia/packages/Reactant/Gn2f9/src/utils.jl:0
[8] #21
@ ./REPL[41]:3 [inlined]
[9] (::Nothing)(none::var"#21#22", none::Reactant.TracedRNumber{Float64}, none::Reactant.TracedRNumber{Float64})
@ Reactant ./<missing>:0
[10] TracedRNumber
@ ~/.julia/packages/Reactant/Gn2f9/src/TracedRNumber.jl:181 [inlined]
[11] convert
@ ./number.jl:7 [inlined]
[12] _promote
@ ./promotion.jl:370 [inlined]
[13] promote
@ ./promotion.jl:393 [inlined]
[14] -
@ ./promotion.jl:424 [inlined]
[15] #21
@ ./REPL[41]:3 [inlined]
[16] call_with_reactant(::var"#21#22", ::Reactant.TracedRNumber{Float64}, ::Reactant.TracedRNumber{Float64})
@ Reactant ~/.julia/packages/Reactant/Gn2f9/src/utils.jl:0
[17] make_mlir_fn(f::var"#21#22", args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Nothing, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/Gn2f9/src/TracedUtils.jl:348
[18] elem_apply(::Function, ::Reactant.TracedRArray{Float64, 2}, ::Reactant.TracedRArray{Float64, 2})
@ Reactant.TracedUtils ~/.julia/packages/Reactant/Gn2f9/src/TracedUtils.jl:1155
[19] overloaded_map(f::Function, x::Reactant.TracedRArray{Float64, 2}, xs::Matrix{Float64})
@ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/Gn2f9/src/TracedRArray.jl:1085
[20] #map
@ ~/.julia/packages/Reactant/Gn2f9/src/Overlay.jl:213 [inlined]
[21] (::Nothing)(none::typeof(map), none::var"#21#22", none::Reactant.TracedRArray{Float64, 2}, none::Tuple{Matrix{Float64}})
@ Reactant ./<missing>:0
[22] #map
@ ~/.julia/packages/Reactant/Gn2f9/src/Overlay.jl:208 [inlined]
[23] call_with_reactant(::typeof(map), ::var"#21#22", ::Reactant.TracedRArray{Float64, 2}, ::Matrix{Float64})
@ Reactant ~/.julia/packages/Reactant/Gn2f9/src/utils.jl:0
[24] #mapreduce#824
@ ./reducedim.jl:361 [inlined]
[25] (::Nothing)(none::Base.var"##mapreduce#824", none::@Kwargs{}, none::typeof(mapreduce), none::Function, none::Function, none::Tuple{…})
@ Reactant ./<missing>:0
[26] call_with_reactant(::Base.var"##mapreduce#824", ::@Kwargs{}, ::typeof(mapreduce), ::Function, ::Function, ::Reactant.TracedRArray{…}, ::Union{…})
@ Reactant reducedim.jl:361
[27] mapreduce
@ ./reducedim.jl:361 [inlined]
[28] dirichlet_lpdf
@ ./REPL[41]:2 [inlined]
[29] (::Nothing)(none::typeof(dirichlet_lpdf), none::Matrix{Float64}, none::Float64, none::Reactant.TracedRArray{Float64, 2})
@ Reactant ./<missing>:0
[30] mapreduce
@ ./reducedim.jl:361 [inlined]
[31] dirichlet_lpdf
@ ./REPL[41]:2 [inlined]
[32] call_with_reactant(::typeof(dirichlet_lpdf), ::Matrix{Float64}, ::Float64, ::Reactant.TracedRArray{Float64, 2})
@ Reactant ~/.julia/packages/Reactant/Gn2f9/src/utils.jl:0
[33] logpdf
@ ./REPL[40]:2 [inlined]
[34] (::Nothing)(none::typeof(logpdf), none::ImageDirichlet{…}, none::Reactant.TracedRArray{…})
@ Reactant ./<missing>:0
[35] getproperty
@ ./Base.jl:37 [inlined]
[36] logpdf
@ ./REPL[40]:2 [inlined]
[37] call_with_reactant(::typeof(logpdf), ::ImageDirichlet{Float64, Matrix{…}, Float64}, ::Reactant.TracedRArray{Float64, 2})
@ Reactant ~/.julia/packages/Reactant/Gn2f9/src/utils.jl:0
[38] make_mlir_fn(f::typeof(logpdf), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/Gn2f9/src/TracedUtils.jl:348
[39] make_mlir_fn
@ ~/.julia/packages/Reactant/Gn2f9/src/TracedUtils.jl:277 [inlined]
[40] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(logpdf), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/Gn2f9/src/Compiler.jl:1700
[41] compile_mlir!
@ ~/.julia/packages/Reactant/Gn2f9/src/Compiler.jl:1662 [inlined]
[42] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/Gn2f9/src/Compiler.jl:3651
[43] compile_xla
@ ~/.julia/packages/Reactant/Gn2f9/src/Compiler.jl:3623 [inlined]
[44] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/Gn2f9/src/Compiler.jl:3727
[45] top-level scope
@ ~/.julia/packages/Reactant/Gn2f9/src/Compiler.jl:2796
Some type information was truncated. Use `show(err)` to see complete types.Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels