Skip to content

Commit 1c582f8

Browse files
committed
add some more uses of enzyme_context to access the world
1 parent 1b74dc8 commit 1c582f8

File tree

5 files changed

+17
-0
lines changed

5 files changed

+17
-0
lines changed

src/rules/activityrules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ function julia_activity_rule(f::LLVM.Function)
3131
return
3232
end
3333
world = enzyme_extract_world(f)
34+
# TODO: Access to gutils
3435

3536
# TODO fix the attributor inlining such that this can assert always true
3637
if expectLen != length(parameters(f))

src/rules/customrules.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ function enzyme_custom_setup_args(
145145

146146
ofn = LLVM.parent(LLVM.parent(orig))
147147
world = enzyme_extract_world(ofn)
148+
@assert world == enzyme_context(gutils).world
148149

149150
for arg in jlargs
150151
@assert arg.cc != RemovedParam
@@ -454,6 +455,7 @@ function enzyme_custom_setup_ret(
454455
mode = get_mode(gutils)
455456

456457
world = enzyme_extract_world(LLVM.parent(LLVM.parent(orig)))
458+
@assert world == enzyme_context(gutils).world
457459

458460
needsShadowP = Ref{UInt8}(0)
459461
needsPrimalP = Ref{UInt8}(0)
@@ -577,6 +579,7 @@ end
577579
curent_bb = position(B)
578580
fn = LLVM.parent(curent_bb)
579581
world = enzyme_extract_world(fn)
582+
@assert world == enzyme_context(gutils).world
580583

581584
llvmf = nested_codegen!(mode, mod, fmi, world)
582585

@@ -812,6 +815,7 @@ end
812815

813816
fn = LLVM.parent(LLVM.parent(orig))
814817
world = enzyme_extract_world(fn)
818+
@assert world == enzyme_context(gutils).world
815819

816820
C = EnzymeRules.RevConfig{
817821
Bool(needsPrimal),
@@ -925,6 +929,7 @@ end
925929

926930
fn = LLVM.parent(LLVM.parent(orig))
927931
world = enzyme_extract_world(fn)
932+
@assert world == enzyme_context(gutils).world
928933
@safe_debug "Trying to apply custom forward rule" TT isKWCall
929934

930935
functy = if isKWCall
@@ -1035,6 +1040,7 @@ function enzyme_custom_common_rev(
10351040
curent_bb = position(B)
10361041
fn = LLVM.parent(curent_bb)
10371042
world = enzyme_extract_world(fn)
1043+
@assert world == enzyme_context(gutils).world
10381044

10391045
mode = get_mode(gutils)
10401046

src/rules/llvmrules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1932,6 +1932,7 @@ end
19321932

19331933
fn = LLVM.parent(LLVM.parent(orig))
19341934
world = enzyme_extract_world(fn)
1935+
@assert world == enzyme_context(gutils).world
19351936
if !guaranteed_nonactive(ET, world)
19361937
emit_error(B, orig, "Enzyme: element type $ET of generic_memory_copyto is potentially active ($reg) and not presently supported")
19371938
end

src/rules/parallelrules.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ end
227227
modifiedBetween = (mode != API.DEM_ForwardMode, false)
228228

229229
world = enzyme_extract_world(LLVM.parent(position(B)))
230+
@assert world == enzyme_context(gutils).world
230231

231232
pfuncT = funcT
232233

@@ -550,6 +551,7 @@ end
550551
tt = Tuple{thunkTy,dfuncT,Bool}
551552
mode = get_mode(gutils)
552553
world = enzyme_extract_world(LLVM.parent(position(B)))
554+
@assert world == enzyme_context(gutils).world
553555
entry = nested_codegen!(mode, mod, runtime_pfor_fwd, tt, world)
554556
push!(function_attributes(entry), EnumAttribute("alwaysinline"))
555557

@@ -594,6 +596,7 @@ end
594596
}
595597
mode = get_mode(gutils)
596598
world = enzyme_extract_world(LLVM.parent(position(B)))
599+
@assert world == enzyme_context(gutils).world
597600
entry = nested_codegen!(mode, mod, runtime_pfor_augfwd, tt, world)
598601
push!(function_attributes(entry), EnumAttribute("alwaysinline"))
599602

@@ -627,6 +630,7 @@ end
627630
@register_rev function threadsfor_rev(B, orig, gutils, tape)
628631
mod = LLVM.parent(LLVM.parent(LLVM.parent(orig)))
629632
world = enzyme_extract_world(LLVM.parent(position(B)))
633+
@assert world == enzyme_context(gutils).world
630634
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig)
631635
return
632636
end
@@ -675,6 +679,7 @@ end
675679
mode = get_mode(gutils)
676680

677681
world = enzyme_extract_world(LLVM.parent(position(B)))
682+
@assert world == enzyme_context(gutils).world
678683

679684
ops = collect(operands(orig))
680685

@@ -731,6 +736,7 @@ end
731736
ModifiedBetween = (uncacheable[1] != 0,)
732737

733738
world = enzyme_extract_world(LLVM.parent(position(B)))
739+
@assert world == enzyme_context(gutils).world
734740

735741
ops = collect(operands(orig))
736742

src/rules/typeunstablerules.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,7 @@ function newstruct_common(fwd, run, offset, B, orig, gutils, normalR, shadowR)
442442
width = get_width(gutils)
443443

444444
world = enzyme_extract_world(LLVM.parent(position(B)))
445+
@assert world == enzyme_context(gutils).world
445446

446447
@assert is_constant_value(gutils, origops[offset])
447448
icvs = [is_constant_value(gutils, v) for v in origops[offset+1:end-1]]
@@ -932,6 +933,7 @@ end
932933
else
933934
@assert legal
934935
world = enzyme_extract_world(LLVM.parent(position(B)))
936+
@assert world == enzyme_context(gutils).world
935937
if !guaranteed_nonactive(TT, world)
936938
unsafe_store!(tapeR, shadowres.ref)
937939
end
@@ -1034,6 +1036,7 @@ end
10341036
if legal
10351037
@assert legal
10361038
world = enzyme_extract_world(LLVM.parent(position(B)))
1039+
@assert world == enzyme_context(gutils).world
10371040
torun = !guaranteed_nonactive(TT, world)
10381041
else
10391042
torun = true

0 commit comments

Comments
 (0)