Skip to content

Commit 82cc451

Browse files
authored
Abstract is mixed (#1536)
* Abstract is mixed * fix unionall * fix * more fixups
1 parent a889bb6 commit 82cc451

File tree

5 files changed

+158
-48
lines changed

5 files changed

+158
-48
lines changed

src/compiler.jl

Lines changed: 64 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -252,16 +252,30 @@ end
252252
ActivityState(Int(a1) | Int(a2))
253253
end
254254

255-
struct Merger{seen,worldT,justActive,UnionSret}
255+
struct Merger{seen,worldT,justActive,UnionSret,AbstractIsMixed}
256256
world::worldT
257257
end
258258

259259
@inline element(::Val{T}) where T = T
260260

261-
@inline function (c::Merger{seen,worldT,justActive,UnionSret})(f::Int) where {seen,worldT,justActive,UnionSret}
261+
# From https://github.com/JuliaLang/julia/blob/81813164963f38dcd779d65ecd222fad8d7ed437/src/cgutils.cpp#L570
262+
@inline function isghostty(ty)
263+
if ty === Union{}
264+
return true
265+
end
266+
if Base.isconcretetype(ty) && !ismutabletype(ty)
267+
if sizeof(ty) == 0
268+
return true
269+
end
270+
# TODO consider struct_to_llvm ?
271+
end
272+
return false
273+
end
274+
275+
@inline function (c::Merger{seen,worldT,justActive,UnionSret,AbstractIsMixed})(f::Int) where {seen,worldT,justActive,UnionSret,AbstractIsMixed}
262276
T = element(first(seen))
263277

264-
reftype = ismutabletype(T) || T isa UnionAll
278+
reftype = ismutabletype(T) || (T isa UnionAll && !AbstractIsMixed)
265279

266280
if justActive && reftype
267281
return Val(AnyState)
@@ -273,7 +287,7 @@ end
273287
return Val(AnyState)
274288
end
275289

276-
sub = active_reg_inner(subT, seen, c.world, Val(justActive), Val(UnionSret))
290+
sub = active_reg_inner(subT, seen, c.world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))
277291

278292
if sub == AnyState
279293
Val(AnyState)
@@ -372,24 +386,31 @@ end
372386
end)
373387
end
374388

375-
@inline function active_reg_recur(::Type{ST}, seen::Seen, world, ::Val{justActive}, ::Val{UnionSret}) where {ST, Seen, justActive, UnionSret}
389+
@inline function active_reg_recur(::Type{ST}, seen::Seen, world, ::Val{justActive}, ::Val{UnionSret}, ::Val{AbstractIsMixed}) where {ST, Seen, justActive, UnionSret, AbstractIsMixed}
376390
if ST isa Union
377-
return forcefold(Val(active_reg_recur(ST.a, seen, world, Val(justActive), Val(UnionSret))), Val(active_reg_recur(ST.b, seen, world, Val(justActive), Val(UnionSret))))
391+
return forcefold(Val(active_reg_recur(ST.a, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))), Val(active_reg_recur(ST.b, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))))
378392
end
379-
return active_reg_inner(ST, seen, world, Val(justActive), Val(UnionSret))
393+
return active_reg_inner(ST, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))
380394
end
381395

382-
@inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false))::ActivityState where {ST,T, justActive, UnionSret}
396+
@inline is_vararg_tup(x) = false
397+
@inline is_vararg_tup(::Type{Tuple{Vararg{T2}}}) where T2 = true
398+
399+
@inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false), ::Val{AbstractIsMixed}=Val(false))::ActivityState where {ST,T, justActive, UnionSret, AbstractIsMixed}
383400
if T === Any
384-
return DupState
401+
if AbstractIsMixed
402+
return MixedState
403+
else
404+
return DupState
405+
end
385406
end
386407

387408
if T === Union{}
388409
return AnyState
389410
end
390411

391412
if T <: Complex && !(T isa UnionAll)
392-
return active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret))
413+
return active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))
393414
end
394415

395416
if T <: AbstractFloat
@@ -401,10 +422,14 @@ end
401422
return AnyState
402423
end
403424

404-
if is_arrayorvararg_ty(T) && active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret)) == AnyState
425+
if is_arrayorvararg_ty(T) && active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)) == AnyState
405426
return AnyState
406427
else
407-
return DupState
428+
if AbstractIsMixed && is_vararg_tup(T)
429+
return MixedState
430+
else
431+
return DupState
432+
end
408433
end
409434
end
410435

@@ -434,35 +459,55 @@ end
434459
if T isa UnionAll
435460
aT = Base.argument_datatype(T)
436461
if aT === nothing
437-
return DupState
462+
if AbstractIsMixed
463+
return MixedState
464+
else
465+
return DupState
466+
end
438467
end
439468
if datatype_fieldcount(aT) === nothing
440-
return DupState
469+
if AbstractIsMixed
470+
return MixedState
471+
else
472+
return DupState
473+
end
441474
end
442475
end
443476

444477
if T isa Union
445478
# if sret union, the data is stored in a stack memory location and is therefore
446479
# not unique'd preventing the boxing of the union in the default case
447480
if UnionSret && is_sret_union(T)
448-
return active_reg_recur(T, seen, world, Val(justActive), Val(UnionSret))
481+
return active_reg_recur(T, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))
449482
else
450483
if justActive
451484
return AnyState
452485
end
453486
if active_reg_inner(T.a, seen, world, Val(justActive), Val(UnionSret)) != AnyState
454-
return DupState
487+
if AbstractIsMixed
488+
return MixedState
489+
else
490+
return DupState
491+
end
455492
end
456493
if active_reg_inner(T.b, seen, world, Val(justActive), Val(UnionSret)) != AnyState
457-
return DupState
494+
if AbstractIsMixed
495+
return MixedState
496+
else
497+
return DupState
498+
end
458499
end
459500
end
460501
return AnyState
461502
end
462503

463504
# if abstract it must be by reference
464505
if Base.isabstracttype(T)
465-
return DupState
506+
if AbstractIsMixed
507+
return MixedState
508+
else
509+
return DupState
510+
end
466511
end
467512

468513
if ismutabletype(T)
@@ -504,7 +549,7 @@ end
504549

505550
seen2 = (Val(nT), seen...)
506551

507-
fty = Merger{seen2,typeof(world),justActive, UnionSret}(world)
552+
fty = Merger{seen2,typeof(world),justActive, UnionSret, AbstractIsMixed}(world)
508553

509554
ty = forcefold(Val(AnyState), ntuple(fty, Val(fieldcount(nT)))...)
510555

@@ -1158,20 +1203,6 @@ function permit_inlining!(f::LLVM.Function)
11581203
end
11591204
end
11601205

1161-
# From https://github.com/JuliaLang/julia/blob/81813164963f38dcd779d65ecd222fad8d7ed437/src/cgutils.cpp#L570
1162-
@inline function isghostty(ty)
1163-
if ty === Union{}
1164-
return true
1165-
end
1166-
if Base.isconcretetype(ty) && !ismutabletype(ty)
1167-
if sizeof(ty) == 0
1168-
return true
1169-
end
1170-
# TODO consider struct_to_llvm ?
1171-
end
1172-
return false
1173-
end
1174-
11751206
struct Tape{TapeTy,ShadowTy,ResT}
11761207
internal_tape::TapeTy
11771208
shadow_return::ShadowTy

src/rules/jitrules.jl

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,15 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing,
9595
end
9696
)
9797
else
98+
mixexpr = if Width == 1
99+
quote
100+
iterate_unwrap_augfwd_mix(Val($reverse), refs, $(primargs[i]), $(shadowargs[i]))
101+
end
102+
else
103+
quote
104+
iterate_unwrap_augfwd_batchmix(Val($reverse), refs, Val($Width), $(primargs[i]), $(shadowargs[i]))
105+
end
106+
end
98107
dupexpr = if Width == 1
99108
quote
100109
iterate_unwrap_augfwd_dup(Val($reverse), refs, $(primargs[i]), $(shadowargs[i]))
@@ -110,8 +119,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing,
110119
if $aref == ActiveState
111120
iterate_unwrap_augfwd_act($(primargs[i])...)
112121
elseif $aref == MixedState
113-
T = $(primtypes[i])
114-
throw(AssertionError("Mixed State of type $T is unsupported in apply iterate"))
122+
$mixexpr
115123
else
116124
$dupexpr
117125
end
@@ -586,6 +594,51 @@ end
586594
end
587595
end
588596

597+
@inline function iterate_unwrap_augfwd_mix(::Val{reverse}, vals, args, dargs0) where reverse
598+
dargs = dargs0[]
599+
ntuple(Val(length(args))) do i
600+
Base.@_inline_meta
601+
arg = args[i]
602+
ty = Core.Typeof(arg)
603+
actreg = active_reg_nothrow(ty, Val(nothing))
604+
if actreg == AnyState
605+
Const(arg)
606+
elseif actreg == ActiveState
607+
Active(arg)
608+
elseif actreg == MixedState
609+
darg = Base.inferencebarrier(dargs[i])
610+
MixedDuplicated(arg, push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty})
611+
else
612+
Duplicated(arg, dargs[i])
613+
end
614+
end
615+
end
616+
617+
@inline function iterate_unwrap_augfwd_batchmix(::Val{reverse}, vals, ::Val{Width}, args, dargs) where {reverse, Width}
618+
ntuple(Val(length(args))) do i
619+
Base.@_inline_meta
620+
arg = args[i]
621+
ty = Core.Typeof(arg)
622+
actreg = active_reg_nothrow(ty, Val(nothing))
623+
if actreg == AnyState
624+
Const(arg)
625+
elseif actreg == ActiveState
626+
Active(arg)
627+
elseif actreg == MixedState
628+
BatchMixedDuplicated(arg, ntuple(Val(Width)) do j
629+
Base.@_inline_meta
630+
darg = Base.inferencebarrier(dargs[j][][i])
631+
push_if_not_ref(Val(reverse), vals, darg, ty)::Base.RefValue{ty}
632+
end)
633+
else
634+
BatchDuplicated(arg, ntuple(Val(Width)) do j
635+
Base.@_inline_meta
636+
dargs[j][][i]
637+
end)
638+
end
639+
end
640+
end
641+
589642
@inline function allFirst(::Val{Width}, res) where Width
590643
ntuple(Val(Width)) do i
591644
Base.@_inline_meta

src/rules/typeunstablerules.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ function body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batch
1111
shadow_rets_i = Expr[]
1212
aref = Symbol("active_ref_$i")
1313
for w in 1:Width
14-
sref = Symbol("shadow_"*string(i)*"_"*string(w))
14+
sref = Symbol("sub_shadow_"*string(i)*"_"*string(w))
1515
push!(shadow_rets_i, quote
1616
$sref = if $aref == AnyState
1717
$(primargs[i]);
1818
else
1919
if !ActivityTup[$i]
20-
if $aref == DupState || $aref == MixedState
20+
if ($aref == DupState || $aref == MixedState) && $(batchshadowargs[i][w]) === nothing
2121
prim = $(primargs[i])
2222
throw("Error cannot store inactive but differentiable variable $prim into active tuple")
2323
end
@@ -98,7 +98,7 @@ function body_construct_rev(N, Width, primtypes, active_refs, primargs, batchsha
9898
shad = batchshadowargs[i][w]
9999
out = :(if $(Symbol("active_ref_$i")) == MixedState || $(Symbol("active_ref_$i")) == ActiveState
100100
if $shad isa Base.RefValue
101-
$shad[] = recursive_add($shad[], $expr)
101+
$shad[] = recursive_add($shad[], $expr, identity, guaranteed_nonactive)
102102
else
103103
error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad))
104104
end
@@ -248,10 +248,10 @@ function newstruct_common(fwd, run, offset, B, orig, gutils, normalR, shadowR)
248248
# if any active [e.g. ActiveState / MixedState] data could exist
249249
# err
250250
if !fwd
251-
if !found
251+
if !found_partial
252252
return false
253253
end
254-
act = active_reg_inner(typ, (), world)
254+
act = active_reg_inner(typ_partial, (), world, #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true))
255255
if act == MixedState || act == ActiveState
256256
return false
257257
end
@@ -306,7 +306,7 @@ function common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR)
306306
return false
307307
end
308308

309-
function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)
309+
function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)::Bool
310310
needsShadowP = Ref{UInt8}(0)
311311
needsPrimalP = Ref{UInt8}(0)
312312
activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils))
@@ -379,7 +379,7 @@ function common_f_tuple_fwd(offset, B, orig, gutils, normalR, shadowR)
379379
common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR)
380380
end
381381

382-
function common_f_tuple_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)
382+
function common_f_tuple_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)::Bool
383383
needsShadowP = Ref{UInt8}(0)
384384
needsPrimalP = Ref{UInt8}(0)
385385
activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils))
@@ -420,8 +420,8 @@ function common_f_tuple_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)
420420

421421
unsafe_store!(tapeR, sret.ref)
422422

423-
return false
424423
end
424+
return false
425425
end
426426

427427
function common_f_tuple_rev(offset, B, orig, gutils, tape)
@@ -474,7 +474,7 @@ function f_tuple_fwd(B, orig, gutils, normalR, shadowR)
474474
common_f_tuple_fwd(1, B, orig, gutils, normalR, shadowR)
475475
end
476476

477-
function f_tuple_augfwd(B, orig, gutils, normalR, shadowR, tapeR)
477+
function f_tuple_augfwd(B, orig, gutils, normalR, shadowR, tapeR)::Bool
478478
common_f_tuple_augfwd(1, B, orig, gutils, normalR, shadowR, tapeR)
479479
end
480480

@@ -487,7 +487,7 @@ function new_structv_fwd(B, orig, gutils, normalR, shadowR)
487487
common_newstructv_fwd(1, B, orig, gutils, normalR, shadowR)
488488
end
489489

490-
function new_structv_augfwd(B, orig, gutils, normalR, shadowR, tapeR)
490+
function new_structv_augfwd(B, orig, gutils, normalR, shadowR, tapeR)::Bool
491491
common_newstructv_augfwd(1, B, orig, gutils, normalR, shadowR, tapeR)
492492
end
493493

@@ -525,7 +525,7 @@ function new_structt_fwd(B, orig, gutils, normalR, shadowR)
525525
unsafe_store!(shadowR, shadowres.ref)
526526
return false
527527
end
528-
function new_structt_augfwd(B, orig, gutils, normalR, shadowR, tapeR)
528+
function new_structt_augfwd(B, orig, gutils, normalR, shadowR, tapeR)::Bool
529529
new_structt_fwd(B, orig, gutils, normalR, shadowR)
530530
end
531531

@@ -821,7 +821,7 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}
821821
return nothing
822822
end
823823

824-
function common_jl_getfield_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)
824+
function common_jl_getfield_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR)::Bool
825825
if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL
826826
return true
827827
end

test/mixedapplyiter.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,4 +141,27 @@ end
141141
@test out[] 5562.9996
142142
@test tupapprox(dx, [[(4.0, [5.4]), (6.0, [6.28])], [(15.8, [94.0]), (22.4, [112.0])]])
143143
@test tupapprox(dx2, [[(3*4.0, [3*5.4]), (3*6.0, [3*6.28])], [(3*15.8, [3*94.0]), (3*22.4, [3*112.0])]])
144-
end
144+
end
145+
146+
struct MyRectilinearGrid5{FT,FZ}
147+
x :: FT
148+
z :: FZ
149+
end
150+
151+
152+
@inline flatten_tuple(a::Tuple) = @inbounds a[2:end]
153+
@inline flatten_tuple(a::Tuple{<:Any}) = tuple() #inner_flatten_tuple(a[1])...)
154+
155+
function myupdate_state!(model)
156+
tupled = Base.inferencebarrier((model,model))
157+
flatten_tuple(tupled)
158+
return nothing
159+
end
160+
161+
@testset "Abstract type allocation" begin
162+
model = MyRectilinearGrid5{Float64, Vector{Float64}}(0.0, [0.0])
163+
dmodel = MyRectilinearGrid5{Float64, Vector{Float64}}(0.0, [0.0])
164+
autodiff(Enzyme.Reverse,
165+
myupdate_state!,
166+
MixedDuplicated(model, Ref(dmodel)))
167+
end

0 commit comments

Comments
 (0)