Skip to content

Commit 22818bf

Browse files
authored
Fix fwd to not have ref on active rtfix (#2142)
* Fix fwd to not have ref on active rtfix * Update runtests.jl
1 parent 2bfc9b5 commit 22818bf

File tree

2 files changed

+45
-4
lines changed

2 files changed

+45
-4
lines changed

src/errors.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ function julia_error(
286286
end
287287

288288
illegalVal = nothing
289+
mode = get_mode(gutils)
289290

290291
function make_replacement(@nospecialize(cur::LLVM.Value), prevbb::LLVM.IRBuilder)::LLVM.Value
291292
ncur = new_from_original(gutils, cur)
@@ -308,15 +309,27 @@ function julia_error(
308309
isa(cur, LLVM.ConstantExpr) &&
309310
cur == data2
310311
if width == 1
311-
res = emit_allocobj!(prevbb, Base.RefValue{TT})
312-
push!(created, res)
313-
return res
312+
if mode == API.DEM_ForwardMode
313+
instance = make_zero(obj)
314+
return unsafe_to_llvm(prevbb, instance)
315+
else
316+
res = emit_allocobj!(prevbb, Base.RefValue{TT})
317+
push!(created, res)
318+
return res
319+
end
314320
else
315321
shadowres = UndefValue(
316322
LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(cur))),
317323
)
318324
for idx = 1:width
319-
res = emit_allocobj!(prevbb, Base.RefValue{TT})
325+
res = if mode == API.DEM_ForwardMode
326+
instance = make_zero(obj)
327+
unsafe_to_llvm(prevbb, instance)
328+
else
329+
sres = emit_allocobj!(prevbb, Base.RefValue{TT})
330+
push!(created, sres)
331+
sres
332+
end
320333
shadowres = insert_value!(prevbb, shadowres, res, idx - 1)
321334
push!(created, shadowres)
322335
end

test/runtests.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,6 +1272,34 @@ end
12721272
@test dweights[2] 20.
12731273
end
12741274

1275+
1276+
abstract type AbsFwdType end
1277+
1278+
# Two copies of the same type.
1279+
struct FwdNormal1{T<:Real} <: AbsFwdType
1280+
σ::T
1281+
end
1282+
1283+
struct FwdNormal2{T<:Real} <: AbsFwdType
1284+
σ::T
1285+
end
1286+
1287+
fwdlogpdf(d) = d.σ
1288+
1289+
function absactfunc(x)
1290+
dists = AbsFwdType[FwdNormal1{Float64}(1.0), FwdNormal2{Float64}(x)]
1291+
res = Vector{Float64}(undef, 2)
1292+
for i in 1:length(dists)
1293+
@inbounds res[i] = fwdlogpdf(dists[i])
1294+
end
1295+
return @inbounds res[1] + @inbounds res[2]
1296+
end
1297+
1298+
@testset "Forward Mode active runtime activity" begin
1299+
res = Enzyme.autodiff(Enzyme.Forward, Enzyme.Const(absactfunc), Duplicated(2.7, 3.1))
1300+
@test res[1] 3.1
1301+
end
1302+
12751303
# dot product (https://github.com/EnzymeAD/Enzyme.jl/issues/495)
12761304
@testset "Dot product" for T in (Float32, Float64)
12771305
xx = rand(T, 10)

0 commit comments

Comments
 (0)