@@ -311,6 +311,44 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs)
311311
312312 MixedTypes = ntuple (i-> :($ (Symbol (" active_ref_$i " ) == MixedState) ? Ref ($ (Symbol (" type_$i " ))) : $ (Symbol (" type_$i " ))), Val (N))
313313
314+ ending = if Width == 1
315+ quote
316+ if active_reg_nothrow (resT, Val (nothing )) == MixedState && ! (initShadow isa Base. RefValue)
317+ shadow_return = Ref (initShadow)
318+ tape = Tape {typeof(internal_tape), typeof(shadow_return), resT} (internal_tape, shadow_return)
319+ return ReturnType ((origRet, shadow_return, tape))
320+ else
321+ shadow_return = nothing
322+ tape = Tape {typeof(internal_tape), typeof(shadow_return), resT} (internal_tape, shadow_return)
323+ return ReturnType ((origRet, initShadow, tape))
324+ end
325+ end
326+ else
327+ expr = :()
328+ shads = Expr[]
329+ for i in 1 : Width
330+ if i == 1
331+ expr = quote ! (initShadow[$ i] isa Base. RefValue) end
332+ else
333+ expr = quote $ expr || ! (initShadow[$ i] isa Base. RefValue) end
334+ end
335+ push! (shads, quote
336+ Ref (initShadow[$ i])
337+ end )
338+ end
339+ quote
340+ if active_reg_nothrow (resT, Val (nothing )) == MixedState && ($ expr)
341+ shadow_return = ($ (shads... ),)
342+ tape = Tape {typeof(internal_tape), typeof(shadow_return), resT} (internal_tape, shadow_return)
343+ return ReturnType ((origRet, shadow_return... , tape))
344+ else
345+ shadow_return = nothing
346+ tape = Tape {typeof(internal_tape), typeof(shadow_return), resT} (internal_tape, shadow_return)
347+ return ReturnType ((origRet, initShadow... , tape))
348+ end
349+ end
350+ end
351+
314352 return quote
315353 $ (active_refs... )
316354 args = ($ (wrapped... ),)
@@ -384,13 +422,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs)
384422
385423 @assert annotation <: Duplicated || annotation <: DuplicatedNoNeed || annotation <: BatchDuplicated || annotation <: BatchDuplicatedNoNeed
386424
387- shadow_return = nothing
388- tape = Tape {typeof(internal_tape), typeof(shadow_return), resT} (internal_tape, shadow_return)
389- if $ Width == 1
390- return ReturnType ((origRet, initShadow, tape))
391- else
392- return ReturnType ((origRet, initShadow... , tape))
393- end
425+ $ ending
394426 end
395427end
396428
411443 return body_runtime_generic_augfwd (N, Width, wrapped, primtypes, active_refs)
412444end
413445
446+ function nonzero_active_data (x:: T ) where T<: AbstractFloat
447+ return x != zero (T)
448+ end
449+
450+ nonzero_active_data (:: T ) where T<: Base.RefValue = false
451+ nonzero_active_data (:: T ) where T<: Array = false
452+ nonzero_active_data (:: T ) where T<: Ptr = false
453+
454+ function nonzero_active_data (x:: T ) where T
455+ if guaranteed_const (T)
456+ return false
457+ end
458+ if ismutable (x)
459+ return false
460+ end
461+
462+ for f in fieldnames (T)
463+ xi = getfield (x, f)
464+ if nonzero_active_data (xi)
465+ return true
466+ end
467+ end
468+ return false
469+ end
470+
414471function body_runtime_generic_rev (N, Width, wrapped, primttypes, shadowargs, active_refs)
415472 outs = []
416473 for i in 1 : N
@@ -462,6 +519,10 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act
462519 false
463520 end
464521
522+ tt = Tuple{$ (ElTypes... )}
523+ rt = Core. Compiler. return_type (f, tt)
524+ annotation0 = guess_activity (rt, API. DEM_ReverseModePrimal)
525+
465526 if any_mixed
466527 ttM = Tuple{Val{active_refs}, FT, $ (ElTypes... )}
467528 rtM = Core. Compiler. return_type (runtime_mixed_call, ttM)
@@ -479,16 +540,20 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act
479540 Const{typeof (runtime_mixed_call)},
480541 annotationM, Tuple{Const{Val{active_refs}}, dupClosure0 ? Duplicated{FT} : Const{FT}, $ (Types... )}, Val (API. DEM_ReverseModePrimal), width,
481542 ModifiedBetweenM, #= returnPrimal=# Val (true ), #= shadowInit=# Val (false ), FFIABI)
543+
482544 if tape. shadow_return != = nothing
545+ if ! (annotation0M <: Active ) && nonzero_active_data (($ shadowret,))
546+ ET = ($ (ElTypes... ),)
547+ throw (AssertionError (" Shadow value " * string (($ shadowret,))* " returned from type unstable call to $f ($(ET... ) ) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information" ))
548+ end
549+ end
550+ if annotation0M <: Active
483551 adjoint (Const (runtime_mixed_call), Const (Val (active_refs)), dupClosure0 ? Duplicated (f, df) : Const (f), args... , $ shadowret, tape. internal_tape)
484552 else
485553 adjoint (Const (runtime_mixed_call), Const (Val (active_refs)), dupClosure0 ? Duplicated (f, df) : Const (f), args... , tape. internal_tape)
486554 end
487555 nothing
488556 else
489- tt = Tuple{$ (ElTypes... )}
490- rt = Core. Compiler. return_type (f, tt)
491- annotation0 = guess_activity (rt, API. DEM_ReverseModePrimal)
492557
493558 annotation = if $ Width != 1 && annotation0 <: Duplicated
494559 BatchDuplicated{rt, $ Width}
@@ -502,7 +567,13 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act
502567 annotation, Tuple{$ (Types... )}, Val (API. DEM_ReverseModePrimal), width,
503568 ModifiedBetween, #= returnPrimal=# Val (true ), #= shadowInit=# Val (false ), FFIABI)
504569
505- tup = if tape. shadow_return != = nothing
570+ if tape. shadow_return != = nothing
571+ if ! (annotation0 <: Active ) && nonzero_active_data (($ shadowret,))
572+ ET = ($ (ElTypes... ),)
573+ throw (AssertionError (" Shadow value " * string (($ shadowret,))* " returned from type unstable call to $f ($(ET... ) ) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information" ))
574+ end
575+ end
576+ tup = if annotation0 <: Active
506577 adjoint (dupClosure0 ? Duplicated (f, df) : Const (f), args... , $ shadowret, tape. internal_tape)[1 ]
507578 else
508579 adjoint (dupClosure0 ? Duplicated (f, df) : Const (f), args... , tape. internal_tape)[1 ]
0 commit comments