@@ -1125,7 +1125,11 @@ struct Return2
11251125end
11261126
11271127function force_recompute! (mod:: LLVM.Module )
1128- for f in functions (mod), bb in blocks (f), inst in collect (instructions (bb))
1128+ for f in functions (mod), bb in blocks (f)
1129+ iter = LLVM. API. LLVMGetFirstInstruction (bb)
1130+ while iter != C_NULL
1131+ inst = LLVM. Instruction (iter)
1132+ iter = LLVM. API. LLVMGetNextInstruction (iter)
11291133 if isa (inst, LLVM. LoadInst)
11301134 has_loaded = false
11311135 for u in LLVM. uses (inst)
@@ -1170,6 +1174,7 @@ function force_recompute!(mod::LLVM.Module)
11701174 end
11711175 end
11721176 end
1177+ end
11731178end
11741179
11751180function permit_inlining! (f:: LLVM.Function )
@@ -3275,7 +3280,7 @@ end
32753280# Enzyme compiler step
32763281# #
32773282
3278- function annotate! (mod, mode )
3283+ function annotate! (mod:: LLVM.Module )
32793284 inactive = LLVM. StringAttribute (" enzyme_inactive" , " " )
32803285 active = LLVM. StringAttribute (" enzyme_active" , " " )
32813286 no_escaping_alloc = LLVM. StringAttribute (" enzyme_no_escaping_allocation" )
@@ -3891,7 +3896,7 @@ function enzyme_extract_world(fn::LLVM.Function)::UInt
38913896 throw (AssertionError (" Enzyme: could not find world in $(string (fn)) " ))
38923897end
38933898
3894- function enzyme_custom_extract_mi (orig:: LLVM.Instruction , error:: Bool = true )
3899+ function enzyme_custom_extract_mi (orig:: LLVM.CallInst , error:: Bool = true )
38953900 operand = LLVM. called_operand (orig)
38963901 if isa (operand, LLVM. Function)
38973902 return enzyme_custom_extract_mi (operand:: LLVM.Function , error)
@@ -6144,7 +6149,7 @@ end
61446149
61456150using Random
61466151# returns arg, return
6147- function no_type_setting (@nospecialize (specTypes); world = nothing )
6152+ function no_type_setting (@nospecialize (specTypes:: Type{<:Tuple} ); world = nothing )
61486153 # Even though the julia type here is ptr{int8}, the actual data can be something else
61496154 if specTypes. parameters[1 ] == typeof (Random. XoshiroSimd. xoshiro_bulk_simd)
61506155 return (true , false )
@@ -7037,7 +7042,7 @@ end
70377042 end
70387043
70397044 # annotate
7040- annotate! (mod, mode )
7045+ annotate! (mod)
70417046 for name in (" gpu_report_exception" , " report_exception" )
70427047 if haskey (functions (mod), name)
70437048 exc = functions (mod)[name]
@@ -8012,9 +8017,6 @@ end
80128017 :: Type{TapeType} ,
80138018 args:: Vararg{Any,N} ,
80148019) where {RawCall,PT,FA,T,RT,TapeType,N,CC,width,returnPrimal}
8015-
8016- JuliaContext () do ctx
8017- Base. @_inline_meta
80188020 F = eltype (FA)
80198021 is_forward =
80208022 CC <: AugmentedForwardThunk || CC <: ForwardModeThunk || CC <: PrimalErrorThunk
@@ -8263,6 +8265,10 @@ end
82638265 i += 1
82648266 end
82658267
8268+ ts_ctx = JuliaContext ()
8269+ ctx = context (ts_ctx)
8270+ activate (ctx)
8271+ (ir, fn, combinedReturn) = try
82668272
82678273 if is_adjoint
82688274 NT = Tuple{ActiveRetTypes... }
@@ -8441,31 +8447,35 @@ end
84418447
84428448 ir = string (mod)
84438449 fn = LLVM. name (llvm_f)
8450+ (ir, fn, combinedReturn)
8451+ finally
8452+ deactivate (ctx)
8453+ dispose (ts_ctx)
8454+ end
84448455
8445- @assert length (types) == length (ccexprs)
8456+ @assert length (types) == length (ccexprs)
84468457
84478458
8448- if ! (GPUCompiler. isghosttype (PT) || Core. Compiler. isconstType (PT))
8449- return quote
8450- Base. @_inline_meta
8451- Base. llvmcall (
8452- ($ ir, $ fn),
8453- $ combinedReturn,
8454- Tuple{$ PT,$ (types... )},
8455- fptr,
8456- $ (ccexprs... ),
8457- )
8458- end
8459- else
8460- return quote
8461- Base. @_inline_meta
8462- Base. llvmcall (
8463- ($ ir, $ fn),
8464- $ combinedReturn,
8465- Tuple{$ (types... )},
8466- $ (ccexprs... ),
8467- )
8468- end
8459+ if ! (GPUCompiler. isghosttype (PT) || Core. Compiler. isconstType (PT))
8460+ return quote
8461+ Base. @_inline_meta
8462+ Base. llvmcall (
8463+ ($ ir, $ fn),
8464+ $ combinedReturn,
8465+ Tuple{$ PT,$ (types... )},
8466+ fptr,
8467+ $ (ccexprs... ),
8468+ )
8469+ end
8470+ else
8471+ return quote
8472+ Base. @_inline_meta
8473+ Base. llvmcall (
8474+ ($ ir, $ fn),
8475+ $ combinedReturn,
8476+ Tuple{$ (types... )},
8477+ $ (ccexprs... ),
8478+ )
84698479 end
84708480 end
84718481end
@@ -9071,7 +9081,10 @@ include("compiler/reflection.jl")
90719081 )
90729082 copysetfn = meta. entry
90739083 blk = first (blocks (copysetfn))
9074- for inst in collect (instructions (blk))
9084+ iter = LLVM. API. LLVMGetFirstInstruction (blk)
9085+ while iter != C_NULL
9086+ inst = LLVM. Instruction (iter)
9087+ iter = LLVM. API. LLVMGetNextInstruction (iter)
90759088 if isa (inst, LLVM. FenceInst)
90769089 eraseInst (blk, inst)
90779090 end
0 commit comments