Skip to content

Commit c9c79ac

Browse files
authored
Fix batched forward constant get (#2143)
* Fix batched forward constant get * fix * Update typeunstable.jl * Update typeunstable.jl * Update typeunstable.jl * fixup
1 parent 2a24bb5 commit c9c79ac

File tree

3 files changed

+44
-4
lines changed

3 files changed

+44
-4
lines changed

src/errors.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,17 @@ function julia_error(
9494
B::LLVM.API.LLVMBuilderRef,
9595
)::LLVM.API.LLVMValueRef
9696
msg = Base.unsafe_string(cstr)
97+
julia_error(msg, val, errtype, data, data2, B)
98+
end
99+
100+
function julia_error(
101+
msg::String,
102+
val::LLVM.API.LLVMValueRef,
103+
errtype::API.ErrorType,
104+
data::Ptr{Cvoid},
105+
data2::LLVM.API.LLVMValueRef,
106+
B::LLVM.API.LLVMBuilderRef,
107+
)::LLVM.API.LLVMValueRef
97108
bt = nothing
98109
ir = nothing
99110
if val != C_NULL
@@ -331,7 +342,9 @@ function julia_error(
331342
sres
332343
end
333344
shadowres = insert_value!(prevbb, shadowres, res, idx - 1)
334-
push!(created, shadowres)
345+
if shadowres isa LLVM.Instruction
346+
push!(created, shadowres)
347+
end
335348
end
336349
return shadowres
337350
end

src/rules/typeunstablerules.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,6 +1067,7 @@ function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR)
10671067
shadowres = UndefValue(
10681068
LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal))),
10691069
)
1070+
position!(B, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(normal)))
10701071
for idx = 1:width
10711072
shadowres = insert_value!(B, shadowres, normal, idx - 1)
10721073
end
@@ -1534,8 +1535,13 @@ end
15341535
end
15351536
origops = collect(operands(orig))
15361537
width = get_width(gutils)
1537-
if !is_constant_value(gutils, origops[1])
1538-
shadowin = invert_pointer(gutils, origops[1], B)
1538+
if !is_constant_value(gutils, origops[1]) || !get_runtime_activity(gutils)
1539+
shadowin = if !is_constant_value(gutils, origops[1])
1540+
invert_pointer(gutils, origops[1], B)
1541+
else
1542+
estr = "Mismatched activity for: " * string(orig) * " const input " *string(origops[1]) * ", differentiable return"
1543+
LLVM.Value(julia_error(estr, orig.ref, API.ET_MixedActivityError, gutils.ref, origops[1].ref, B.ref))
1544+
end
15391545
if width == 1
15401546
args = LLVM.Value[
15411547
shadowin
@@ -1565,6 +1571,7 @@ end
15651571
shadowres = UndefValue(
15661572
LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(normal))),
15671573
)
1574+
position!(B, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(normal)))
15681575
for idx = 1:width
15691576
shadowres = insert_value!(B, shadowres, normal, idx - 1)
15701577
end

test/typeunstable.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,4 +101,24 @@ end
101101
res = Enzyme.autodiff(Forward, toactivepair, BatchDuplicated(2.7f0, (2.0f0, 5.0f0)), BatchDuplicated(3.1, (3.0, 7.0)))
102102
@test res[1][1] 2.7f0 * 3.0 + 2.0f0 * 3.1
103103
@test res[1][2] 2.7f0 * 7.0 + 5.0f0 * 3.1
104-
end
104+
end
105+
106+
struct InsFwdNormal1{T<:Real}
107+
σ::T
108+
end
109+
110+
struct InsFwdNormal2{T<:Real}
111+
σ::T
112+
end
113+
114+
insfwdlogpdf(d, x) = d.σ
115+
116+
function insfwdfunc(x)
117+
dists = [InsFwdNormal1{Float64}(1.0), InsFwdNormal2{Float64}(1.0)]
118+
return sum(Base.Fix2(insfwdlogpdf, x), dists)
119+
end
120+
121+
@testset "Forward Batch Constant insertion" begin
122+
res = Enzyme.gradient(Enzyme.Forward, insfwdfunc, [0.5, 0.7])[1]
123+
@test res [0.0, 0.0]
124+
end

0 commit comments

Comments
 (0)