@@ -199,7 +199,48 @@ function emit_jl_throw!(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM
199199 T_prjlvalue = LLVM. PointerType (T_jlvalue, 12 )
200200 FT = LLVM. FunctionType (T_void, [T_prjlvalue])
201201 fn, _ = get_function! (mod, " jl_throw" , FT)
202- call! (B, FT, fn, LLVM. Value[val])
202+ cb = call! (B, FT, fn, LLVM. Value[val])
203+ LLVM. API. LLVMAddCallSiteAttribute (
204+ cb,
205+ reinterpret (LLVM. API. LLVMAttributeIndex, LLVM. API. LLVMAttributeFunctionIndex),
206+ EnumAttribute (" noreturn" ),
207+ )
208+ return cb
209+ end
210+
211+ function emit_conditional_throw! (B:: LLVM.IRBuilder , @nospecialize (val:: LLVM.Value ), @nospecialize (cond:: LLVM.Value )):: LLVM.Value
212+ curent_bb = position (B)
213+ fn = LLVM. parent (curent_bb)
214+ mod = LLVM. parent (fn)
215+ T_void = LLVM. VoidType ()
216+ T_jlvalue = LLVM. StructType (LLVMType[])
217+ T_prjlvalue = LLVM. PointerType (T_jlvalue, 10 )
218+ FT = LLVM. FunctionType (T_void, [T_prjlvalue, LLVM. IntType (1 )])
219+
220+ name = " jl_conditional_throw"
221+ if haskey (functions (mod), name)
222+ fn = functions (mod)[name]
223+ else
224+ fn = LLVM. Function (mod, " jl_conditional_throw" , FT)
225+ linkage! (fn, LLVM. API. LLVMInternalLinkage)
226+ err, rcond = LLVM. parameters (fn)
227+ builder = LLVM. IRBuilder ()
228+ entry = BasicBlock (fn, " entry" )
229+ errb = BasicBlock (fn, " err" )
230+ exitb = BasicBlock (fn, " errb" )
231+ position! (builder, entry)
232+ br! (builder, rcond, errb, exitb)
233+ position! (builder, errb)
234+ err = addrspacecast! (builder, err, LLVM. PointerType (T_jlvalue, 12 ))
235+ thrown = emit_jl_throw! (builder, err)
236+ unreachable! (builder)
237+ position! (builder, exitb)
238+ ret! (builder)
239+
240+ push! (LLVM. function_attributes (fn), LLVM. EnumAttribute (" alwaysinline" , 0 ))
241+ end
242+
243+ call! (B, FT, fn, LLVM. Value[val, cond])
203244end
204245
205246function emit_box_int32! (B:: LLVM.IRBuilder , @nospecialize (val:: LLVM.Value )):: LLVM.Value
@@ -1004,7 +1045,7 @@ function allocate_sret!(gutils::API.EnzymeGradientUtilsRef, @nospecialize(N::LLV
10041045 allocate_sret! (B, N)
10051046end
10061047
1007- function emit_error (B:: LLVM.IRBuilder , @nospecialize (orig:: Union{Nothing, LLVM.Instruction} ), string:: Union{String, LLVM.Value} , @nospecialize (errty:: Type ) = EnzymeRuntimeException)
1048+ function emit_error (B:: LLVM.IRBuilder , @nospecialize (orig:: Union{Nothing, LLVM.Instruction} ), string:: Union{String, LLVM.Value} , @nospecialize (errty:: Type ) = EnzymeRuntimeException, @nospecialize (cond :: Union{Nothing, LLVM.Value} ) = nothing )
10081049 curent_bb = position (B)
10091050 fn = LLVM. parent (curent_bb)
10101051 mod = LLVM. parent (fn)
@@ -1061,24 +1102,30 @@ function emit_error(B::LLVM.IRBuilder, @nospecialize(orig::Union{Nothing, LLVM.I
10611102 err = emit_allocobj! (B, errty)
10621103 err2 = bitcast! (B, err, LLVM. PointerType (LLVM. PointerType (LLVM. Int8Type ()), 10 ))
10631104 store! (B, string, err2)
1064- emit_jl_throw! (
1065- B,
1066- addrspacecast! (B, err, LLVM. PointerType (LLVM. StructType (LLVMType[]), 12 )),
1067- )
1105+ if cond != = nothing
1106+ emit_conditional_throw! (B, err, cond)
1107+ else
1108+ emit_jl_throw! (
1109+ B,
1110+ addrspacecast! (B, err, LLVM. PointerType (LLVM. StructType (LLVMType[]), 12 )),
1111+ )
1112+ end
10681113 end
10691114
10701115 # 2. Call error function and insert unreachable
1071- LLVM. API. LLVMAddCallSiteAttribute (
1072- ct,
1073- reinterpret (LLVM. API. LLVMAttributeIndex, LLVM. API. LLVMAttributeFunctionIndex),
1074- EnumAttribute (" noreturn" ),
1075- )
1076- if EnzymeMutabilityException != errty
1077- LLVM. API. LLVMAddCallSiteAttribute (
1078- ct,
1079- reinterpret (LLVM. API. LLVMAttributeIndex, LLVM. API. LLVMAttributeFunctionIndex),
1080- StringAttribute (" enzyme_error" ),
1081- )
1116+ if cond === nothing
1117+ LLVM. API. LLVMAddCallSiteAttribute (
1118+ ct,
1119+ reinterpret (LLVM. API. LLVMAttributeIndex, LLVM. API. LLVMAttributeFunctionIndex),
1120+ EnumAttribute (" noreturn" ),
1121+ )
1122+ if EnzymeMutabilityException != errty
1123+ LLVM. API. LLVMAddCallSiteAttribute (
1124+ ct,
1125+ reinterpret (LLVM. API. LLVMAttributeIndex, LLVM. API. LLVMAttributeFunctionIndex),
1126+ StringAttribute (" enzyme_error" ),
1127+ )
1128+ end
10821129 end
10831130 return ct
10841131end
0 commit comments