Skip to content

Commit 8a84970

Browse files
authored
loosen fwd runtime activity blas restrictions (#2398)
* loosen fwd runtime activity blas restrictions * fix'
1 parent aa67a5a commit 8a84970

File tree

4 files changed

+111
-19
lines changed

4 files changed

+111
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5"
3939
CEnum = "0.4, 0.5"
4040
ChainRulesCore = "1"
4141
EnzymeCore = "0.8.8"
42-
Enzyme_jll = "0.0.179"
42+
Enzyme_jll = "0.0.180"
4343
GPUArraysCore = "0.1.6, 0.2"
4444
GPUCompiler = "1.3"
4545
LLVM = "6.1, 7, 8, 9"

src/errors.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,15 @@ function julia_error(
310310
println(io)
311311
end
312312
end
313-
emit_error(B, nothing, msg2, EnzymeNoDerivativeError)
313+
if data2 != C_NULL
314+
data2 = LLVM.Value(data2)
315+
if value_type(data2) != LLVM.IntType(1)
316+
data2 = nothing
317+
end
318+
else
319+
data2 = nothing
320+
end
321+
emit_error(B, nothing, msg2, EnzymeNoDerivativeError, data2)
314322
return C_NULL
315323
end
316324
throw(NoDerivativeException(msg, ir, bt))

src/jlrt.jl

Lines changed: 64 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,48 @@ function emit_jl_throw!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM
199199
T_prjlvalue = LLVM.PointerType(T_jlvalue, 12)
200200
FT = LLVM.FunctionType(T_void, [T_prjlvalue])
201201
fn, _ = get_function!(mod, "jl_throw", FT)
202-
call!(B, FT, fn, LLVM.Value[val])
202+
cb = call!(B, FT, fn, LLVM.Value[val])
203+
LLVM.API.LLVMAddCallSiteAttribute(
204+
cb,
205+
reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex),
206+
EnumAttribute("noreturn"),
207+
)
208+
return cb
209+
end
210+
211+
function emit_conditional_throw!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value), @nospecialize(cond::LLVM.Value))::LLVM.Value
212+
curent_bb = position(B)
213+
fn = LLVM.parent(curent_bb)
214+
mod = LLVM.parent(fn)
215+
T_void = LLVM.VoidType()
216+
T_jlvalue = LLVM.StructType(LLVMType[])
217+
T_prjlvalue = LLVM.PointerType(T_jlvalue, 10)
218+
FT = LLVM.FunctionType(T_void, [T_prjlvalue, LLVM.IntType(1)])
219+
220+
name = "jl_conditional_throw"
221+
if haskey(functions(mod), name)
222+
fn = functions(mod)[name]
223+
else
224+
fn = LLVM.Function(mod, "jl_conditional_throw", FT)
225+
linkage!(fn, LLVM.API.LLVMInternalLinkage)
226+
err, rcond = LLVM.parameters(fn)
227+
builder = LLVM.IRBuilder()
228+
entry = BasicBlock(fn, "entry")
229+
errb = BasicBlock(fn, "err")
230+
exitb = BasicBlock(fn, "errb")
231+
position!(builder, entry)
232+
br!(builder, rcond, errb, exitb)
233+
position!(builder, errb)
234+
err = addrspacecast!(builder, err, LLVM.PointerType(T_jlvalue, 12))
235+
thrown = emit_jl_throw!(builder, err)
236+
unreachable!(builder)
237+
position!(builder, exitb)
238+
ret!(builder)
239+
240+
push!(LLVM.function_attributes(fn), LLVM.EnumAttribute("alwaysinline", 0))
241+
end
242+
243+
call!(B, FT, fn, LLVM.Value[val, cond])
203244
end
204245

205246
function emit_box_int32!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value
@@ -1004,7 +1045,7 @@ function allocate_sret!(gutils::API.EnzymeGradientUtilsRef, @nospecialize(N::LLV
10041045
allocate_sret!(B, N)
10051046
end
10061047

1007-
function emit_error(B::LLVM.IRBuilder, @nospecialize(orig::Union{Nothing, LLVM.Instruction}), string::Union{String, LLVM.Value}, @nospecialize(errty::Type) = EnzymeRuntimeException)
1048+
function emit_error(B::LLVM.IRBuilder, @nospecialize(orig::Union{Nothing, LLVM.Instruction}), string::Union{String, LLVM.Value}, @nospecialize(errty::Type) = EnzymeRuntimeException, @nospecialize(cond::Union{Nothing, LLVM.Value}) = nothing)
10081049
curent_bb = position(B)
10091050
fn = LLVM.parent(curent_bb)
10101051
mod = LLVM.parent(fn)
@@ -1061,24 +1102,30 @@ function emit_error(B::LLVM.IRBuilder, @nospecialize(orig::Union{Nothing, LLVM.I
10611102
err = emit_allocobj!(B, errty)
10621103
err2 = bitcast!(B, err, LLVM.PointerType(LLVM.PointerType(LLVM.Int8Type()), 10))
10631104
store!(B, string, err2)
1064-
emit_jl_throw!(
1065-
B,
1066-
addrspacecast!(B, err, LLVM.PointerType(LLVM.StructType(LLVMType[]), 12)),
1067-
)
1105+
if cond !== nothing
1106+
emit_conditional_throw!(B, err, cond)
1107+
else
1108+
emit_jl_throw!(
1109+
B,
1110+
addrspacecast!(B, err, LLVM.PointerType(LLVM.StructType(LLVMType[]), 12)),
1111+
)
1112+
end
10681113
end
10691114

10701115
# 2. Call error function and insert unreachable
1071-
LLVM.API.LLVMAddCallSiteAttribute(
1072-
ct,
1073-
reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex),
1074-
EnumAttribute("noreturn"),
1075-
)
1076-
if EnzymeMutabilityException != errty
1077-
LLVM.API.LLVMAddCallSiteAttribute(
1078-
ct,
1079-
reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex),
1080-
StringAttribute("enzyme_error"),
1081-
)
1116+
if cond === nothing
1117+
LLVM.API.LLVMAddCallSiteAttribute(
1118+
ct,
1119+
reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex),
1120+
EnumAttribute("noreturn"),
1121+
)
1122+
if EnzymeMutabilityException != errty
1123+
LLVM.API.LLVMAddCallSiteAttribute(
1124+
ct,
1125+
reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex),
1126+
StringAttribute("enzyme_error"),
1127+
)
1128+
end
10821129
end
10831130
return ct
10841131
end

test/internal_rules.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,43 @@ end
282282
end
283283
end
284284

285+
function two_blas(a, b)
286+
a = copy(a)
287+
@inline LinearAlgebra.LAPACK.potrf!('L', a)
288+
@inline LinearAlgebra.LAPACK.potrf!('L', b)
289+
return a[1,1] + b[1,1]
290+
end
291+
292+
@testset "Forward Mode runtime activity" begin
293+
294+
a = [2.7 3.5; 7.4 9.2]
295+
da = [7.2 5.3; 4.7 2.9]
296+
297+
b = [3.1 5.6; 13 19]
298+
db = [1.3 6.5; .13 .19]
299+
300+
res = Enzyme.autodiff(Forward, two_blas, Duplicated(a, da), Duplicated(b, db))[1]
301+
@test res 2.5600654222812564
302+
303+
a = [2.7 3.5; 7.4 9.2]
304+
da = [7.2 5.3; 4.7 2.9]
305+
306+
b = [3.1 5.6; 13 19]
307+
db = [1.3 6.5; .13 .19]
308+
309+
res = Enzyme.autodiff(set_runtime_activity(Forward), two_blas, Duplicated(a, da), Duplicated(b, db))[1]
310+
@test res 2.5600654222812564
311+
312+
a = [2.7 3.5; 7.4 9.2]
313+
da = [7.2 5.3; 4.7 2.9]
314+
315+
b = [3.1 5.6; 13 19]
316+
db = [1.3 6.5; .13 .19]
317+
318+
@test_throws Enzyme.Compiler.EnzymeNoDerivativeError Enzyme.autodiff(set_runtime_activity(Forward), two_blas, Duplicated(a, da), Duplicated(b, b))
319+
320+
end
321+
285322
@testset "Cholesky" begin
286323
function symmetric_definite(n :: Int=10)
287324
α = one(Float64)

0 commit comments

Comments
 (0)