Skip to content

Commit e8f3cee

Browse files
authored
Lowered return literal (#2639)
* Lowered return literal * with test * mixed test * fix
1 parent f6839ee commit e8f3cee

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

src/compiler.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3676,7 +3676,12 @@ function lower_convention(
36763676
sret = sret !== nothing
36773677
returnRoots = returnRoots !== nothing
36783678

3679-
loweredReturn = RetActivity <: Active && (actualRetType === Any)
3679+
loweredReturn = RetActivity <: Active && !allocatedinline(actualRetType)
3680+
if (RetActivity <: Active || RetActivity <: MixedDuplicated || RetActivity <: BatchMixedDuplicated) && (allocatedinline(actualRetType) != allocatedinline(eltype(RetActivity)))
3681+
@assert !allocatedinline(actualRetType)
3682+
loweredReturn = true
3683+
end
3684+
36803685
expected_RT = Nothing
36813686
if loweredReturn
36823687
@assert !sret

test/mixed.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,40 @@ end
121121
flattened_unique_values,
122122
Duplicated(thing, dthing))
123123
end
124+
125+
126+
127+
128+
function literalrt(x)
129+
y = Base.inferencebarrier(x * x)
130+
y2 = Base.inferencebarrier(x * x * x)
131+
return (y, y2)
132+
end
133+
134+
@testset "Literal RT mismatch" begin
135+
fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(literalrt)}, Active{Tuple{Float64, Float64}}, Active{Float64})
136+
137+
tape, = fwd(Const(literalrt), Active(3.1))
138+
139+
x = 3.1
140+
@test rev(Const(literalrt), Active(3.1), (2.7, 0.2), tape)[1][1] 2 * x * 2.7 + 3 * x * x * 0.2
141+
142+
end
143+
144+
function literalrt_mixed(x)
145+
y = Base.inferencebarrier(x * x)
146+
y2 = Base.inferencebarrier([x * x * x])
147+
return (y, y2)
148+
end
149+
150+
@testset "Mixed Literal RT mismatch" begin
151+
fwd, rev = Enzyme.autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(literalrt_mixed)}, MixedDuplicated{Tuple{Float64, Vector{Float64}}}, Active{Float64})
152+
153+
tape, prim, shad = fwd(Const(literalrt_mixed), Active(3.1))
154+
155+
shad[][2][1] = 0.2
156+
157+
x = 3.1
158+
@test rev(Const(literalrt_mixed), Active(3.1), (2.7, shad[][2]), tape)[1][1] 2 * x * 2.7 + 3 * x * x * 0.2
159+
end
160+

0 commit comments

Comments
 (0)