Skip to content

Commit df7dd87

Browse files
authored
Handle non-zero mixed return (#1529)
* Handle non-zero mixed return * improve mixed activity rule errors
1 parent b8f9beb commit df7dd87

File tree

2 files changed

+90
-15
lines changed

2 files changed

+90
-15
lines changed

src/rules/customrules.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils,
207207
return args, activity, (overwritten...,), actives, kwtup
208208
end
209209

210-
function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt))
210+
function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi, @nospecialize(RealRt), B)
211211
width = get_width(gutils)
212212
mode = get_mode(gutils)
213213

@@ -246,10 +246,14 @@ function enzyme_custom_setup_ret(gutils::GradientUtils, orig::LLVM.CallInst, mi,
246246
activep = API.DFT_DUP_NONEED
247247
end
248248

249+
249250
if activep == API.DFT_CONSTANT
250251
RT = Const{RealRt}
251252

252-
elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg(RealRt, world) )
253+
elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg_inner(RealRt, (), world, #=justActive=#Val(true)) == ActiveState)
254+
if active_reg_inner(RealRt, (), world, #=justActive=#Val(false)) == MixedState && B !== nothing
255+
emit_error(B, orig, "Enzyme: Return type $RealRt has mixed internal activity types in evaluation of custom rule for $mi. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")
256+
end
253257
RT = Active{RealRt}
254258

255259
elseif activep == API.DFT_DUP_ARG
@@ -298,7 +302,7 @@ function enzyme_custom_fwd(B, orig, gutils, normalR, shadowR)
298302

299303
# 2) Create activity, and annotate function spec
300304
args, activity, overwritten, actives, kwtup = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#false, isKWCall)
301-
RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt)
305+
RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B)
302306

303307
alloctx = LLVM.IRBuilder()
304308
position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils)))
@@ -511,7 +515,7 @@ end
511515

512516
# 2) Create activity, and annotate function spec
513517
args, activity, overwritten, actives, kwtup = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#!forward, isKWCall)
514-
RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt)
518+
RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt, B)
515519

516520
needsShadowJL = if RT <: Active
517521
false

src/rules/jitrules.jl

Lines changed: 82 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,44 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs)
311311

312312
MixedTypes = ntuple(i->:($(Symbol("active_ref_$i") == MixedState) ? Ref($(Symbol("type_$i"))) : $(Symbol("type_$i"))), Val(N))
313313

314+
ending = if Width == 1
315+
quote
316+
if active_reg_nothrow(resT, Val(nothing)) == MixedState && !(initShadow isa Base.RefValue)
317+
shadow_return = Ref(initShadow)
318+
tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return)
319+
return ReturnType((origRet, shadow_return, tape))
320+
else
321+
shadow_return = nothing
322+
tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return)
323+
return ReturnType((origRet, initShadow, tape))
324+
end
325+
end
326+
else
327+
expr = :()
328+
shads = Expr[]
329+
for i in 1:Width
330+
if i == 1
331+
expr = quote !(initShadow[$i] isa Base.RefValue) end
332+
else
333+
expr = quote $expr || !(initShadow[$i] isa Base.RefValue) end
334+
end
335+
push!(shads, quote
336+
Ref(initShadow[$i])
337+
end)
338+
end
339+
quote
340+
if active_reg_nothrow(resT, Val(nothing)) == MixedState && ($expr)
341+
shadow_return = ($(shads...),)
342+
tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return)
343+
return ReturnType((origRet, shadow_return..., tape))
344+
else
345+
shadow_return = nothing
346+
tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return)
347+
return ReturnType((origRet, initShadow..., tape))
348+
end
349+
end
350+
end
351+
314352
return quote
315353
$(active_refs...)
316354
args = ($(wrapped...),)
@@ -384,13 +422,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs)
384422

385423
@assert annotation <: Duplicated || annotation <: DuplicatedNoNeed || annotation <: BatchDuplicated || annotation <: BatchDuplicatedNoNeed
386424

387-
shadow_return = nothing
388-
tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return)
389-
if $Width == 1
390-
return ReturnType((origRet, initShadow, tape))
391-
else
392-
return ReturnType((origRet, initShadow..., tape))
393-
end
425+
$ending
394426
end
395427
end
396428

@@ -411,6 +443,31 @@ end
411443
return body_runtime_generic_augfwd(N, Width, wrapped, primtypes, active_refs)
412444
end
413445

446+
function nonzero_active_data(x::T) where T<: AbstractFloat
447+
return x != zero(T)
448+
end
449+
450+
nonzero_active_data(::T) where T<: Base.RefValue = false
451+
nonzero_active_data(::T) where T<: Array = false
452+
nonzero_active_data(::T) where T<: Ptr = false
453+
454+
function nonzero_active_data(x::T) where T
455+
if guaranteed_const(T)
456+
return false
457+
end
458+
if ismutable(x)
459+
return false
460+
end
461+
462+
for f in fieldnames(T)
463+
xi = getfield(x, f)
464+
if nonzero_active_data(xi)
465+
return true
466+
end
467+
end
468+
return false
469+
end
470+
414471
function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, active_refs)
415472
outs = []
416473
for i in 1:N
@@ -462,6 +519,10 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act
462519
false
463520
end
464521

522+
tt = Tuple{$(ElTypes...)}
523+
rt = Core.Compiler.return_type(f, tt)
524+
annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal)
525+
465526
if any_mixed
466527
ttM = Tuple{Val{active_refs}, FT, $(ElTypes...)}
467528
rtM = Core.Compiler.return_type(runtime_mixed_call, ttM)
@@ -479,16 +540,20 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act
479540
Const{typeof(runtime_mixed_call)},
480541
annotationM, Tuple{Const{Val{active_refs}}, dupClosure0 ? Duplicated{FT} : Const{FT}, $(Types...)}, Val(API.DEM_ReverseModePrimal), width,
481542
ModifiedBetweenM, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI)
543+
482544
if tape.shadow_return !== nothing
545+
if !(annotation0M <: Active) && nonzero_active_data(($shadowret,))
546+
ET = ($(ElTypes...),)
547+
throw(AssertionError("Shadow value "*string(($shadowret,))*" returned from type unstable call to $f($(ET...)) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information"))
548+
end
549+
end
550+
if annotation0M <: Active
483551
adjoint(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)
484552
else
485553
adjoint(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)
486554
end
487555
nothing
488556
else
489-
tt = Tuple{$(ElTypes...)}
490-
rt = Core.Compiler.return_type(f, tt)
491-
annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal)
492557

493558
annotation = if $Width != 1 && annotation0 <: Duplicated
494559
BatchDuplicated{rt, $Width}
@@ -502,7 +567,13 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act
502567
annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width,
503568
ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI)
504569

505-
tup = if tape.shadow_return !== nothing
570+
if tape.shadow_return !== nothing
571+
if !(annotation0 <: Active) && nonzero_active_data(($shadowret,))
572+
ET = ($(ElTypes...),)
573+
throw(AssertionError("Shadow value "*string(($shadowret,))*" returned from type unstable call to $f($(ET...)) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information"))
574+
end
575+
end
576+
tup = if annotation0 <: Active
506577
adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1]
507578
else
508579
adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1]

0 commit comments

Comments
 (0)