Skip to content

Commit 6c2b0d9

Browse files
authored
Improve rule arg mixed errors (#1530)
* Improve rule arg mixed errors * fixup * improve errs
1 parent df7dd87 commit 6c2b0d9

File tree

3 files changed

+31
-9
lines changed

3 files changed

+31
-9
lines changed

src/rules/customrules.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,14 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils,
122122

123123
push!(activity, Ty)
124124

125-
elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg(arg.typ, world) )
125+
elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg_inner(arg.typ, (), world, #=justActive=#Val(true)) == ActiveState)
126126
Ty = Active{arg.typ}
127127
llty = convert(LLVMType, Ty)
128128
arty = convert(LLVMType, arg.typ; allow_boxed=true)
129129
if B !== nothing
130+
if active_reg_inner(arg.typ, (), world, #=justActive=#Val(false)) == MixedState
131+
emit_error(B, orig, "Enzyme: Argument type $(arg.typ) has mixed internal activity types in evaluation of custom rule for $mi. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")
132+
end
130133
al0 = al = emit_allocobj!(B, Ty)
131134
al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al))))
132135
al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived))
@@ -716,6 +719,17 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
716719
tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4]))
717720
innerTy = value_type(parameters(llvmf)[tape_idx+(sret !== nothing)+(RT <: Active)])
718721
if innerTy != value_type(tape)
722+
if isabstracttype(TapeT)
723+
msg = sprint() do io
724+
println(io, "Enzyme : mismatch between innerTy $innerTy and tape type $(value_type(tape))")
725+
println(io, "tape_idx=", tape_idx)
726+
println(io, "sret=", sret)
727+
println(io, "RT=", RT)
728+
println(io, "tape=", tape)
729+
println(io, "llvmf=", string(llvmf))
730+
end
731+
throw(AssertionError(msg))
732+
end
719733
llty = convert(LLVMType, TapeT; allow_boxed=true)
720734
al0 = al = emit_allocobj!(B, TapeT)
721735
al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al))))

src/rules/jitrules.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,7 +1111,7 @@ for (N, Width) in Iterators.product(0:30, 1:10)
11111111
eval(func_runtime_iterate_rev(N, Width))
11121112
end
11131113

1114-
function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false, endcast=true)
1114+
function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false, endcast=true, firstconst_after_tape=true)
11151115
width = get_width(gutils)
11161116
mode = get_mode(gutils)
11171117
mod = LLVM.parent(LLVM.parent(LLVM.parent(orig)))
@@ -1132,7 +1132,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder,
11321132
T_jlvalue = LLVM.StructType(LLVM.LLVMType[])
11331133
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
11341134

1135-
if firstconst
1135+
if firstconst && !firstconst_after_tape
11361136
val = new_from_original(gutils, operands(orig)[start])
11371137
if lookup
11381138
val = lookup_value(gutils, val, B)
@@ -1196,6 +1196,14 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder,
11961196
else
11971197
pushfirst!(vals, unsafe_to_llvm(Val(ReturnType)))
11981198
end
1199+
1200+
if firstconst && firstconst_after_tape
1201+
val = new_from_original(gutils, operands(orig)[start])
1202+
if lookup
1203+
val = lookup_value(gutils, val, B)
1204+
end
1205+
pushfirst!(vals, val)
1206+
end
11991207

12001208
if mode != API.DEM_ForwardMode
12011209
uncacheable = get_uncacheable(gutils, orig)

src/rules/typeunstablerules.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,19 +181,19 @@ function body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primarg
181181
end
182182

183183
function func_runtime_newstruct_augfwd(N, Width)
184-
primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width)
184+
primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width; mixed_or_active=true)
185185
body = body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs)
186186

187187
quote
188-
function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, ::Type{NewType}, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, NewType, $(typeargs...)}
188+
function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, ::Type{NewType}, RT::Val{ReturnType}, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, NewType, $(typeargs...)}
189189
$body
190190
end
191191
end
192192
end
193193

194-
@generated function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, ::Type{NewType}, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType, NewType}
194+
@generated function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, ::Type{NewType}, RT::Val{ReturnType}, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType, NewType}
195195
N = div(length(allargs)+2, Width+1)-1
196-
primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs)
196+
primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs; mixed_or_active=true)
197197
return body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs)
198198
end
199199

@@ -325,7 +325,7 @@ function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tap
325325

326326
width = get_width(gutils)
327327

328-
sret = generic_setup(orig, runtime_newstruct_augfwd, width == 1 ? Any : AnyArray(Int(width)), gutils, #=start=#offset, B, false; firstconst=true, endcast = false)
328+
sret = generic_setup(orig, runtime_newstruct_augfwd, width == 1 ? Any : AnyArray(Int(width)), gutils, #=start=#offset, B, false; firstconst=true, endcast = false, firstconst_after_tape=true)
329329

330330
if width == 1
331331
shadow = sret
@@ -369,7 +369,7 @@ function common_newstructv_rev(offset, B, orig, gutils, tape)
369369
if !newstruct_common(#=fwd=#false, #=run=#false, offset, B, orig, gutils, #=normalR=#nothing, #=shadowR=#nothing)
370370
@assert tape !== C_NULL
371371
width = get_width(gutils)
372-
generic_setup(orig, runtime_newstruct_rev, Nothing, gutils, #=start=#offset, B, true; firstconst=true, tape)
372+
generic_setup(orig, runtime_newstruct_rev, Nothing, gutils, #=start=#offset, B, true; firstconst=true, tape, firstconst_after_tape=true)
373373
end
374374

375375
return nothing

0 commit comments

Comments
 (0)