Skip to content

Commit 15e1a93

Browse files
authored
Different-typed forward grad (#2574)
* Different-typed forward grad * fix
1 parent 3b76eb6 commit 15e1a93

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

src/sugar.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ end
168168
@inline function onehot(x::NTuple{N,T}) where {T,N}
169169
onehot(NTuple{N,T})
170170
end
171-
@inline function onehot(x::NTuple{N,T}, start, endl) where {T,N}
171+
@inline function onehot(x::NTuple{N,T}, start::Int, endl::Int) where {T,N}
172172
ntuple(Val(endl - start + 1)) do i
173173
Base.@_inline_meta
174174
ntuple(Val(N)) do idx
@@ -182,6 +182,17 @@ end
182182
return (one(x),)
183183
end
184184

185+
@inline function onehot(x::Tuple{Vararg{<:AbstractFloat}})
186+
ntuple(Val(length(x))) do i
187+
Base.@_inline_meta
188+
ntuple(Val(length(x))) do idx
189+
Base.@_inline_meta
190+
T = typeof(x[idx])
191+
return (i == idx) ? T(1) : T(0)
192+
end
193+
end
194+
end
195+
185196
"""
186197
gradient(::ReverseMode, f, args...)
187198

test/sugar.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,14 @@ end
191191

192192
end
193193

194+
tupsq(t) = t[1]*t[2]
195+
196+
@testset "Forward differing element tuple of floats" begin
197+
res = Enzyme.gradient(Enzyme.Forward, tupsq, (3.0, Float32(2.0)))[1]
198+
@test res[1] 2.0
199+
@test res[2] 3.0
200+
end
201+
194202
# these are used in gradient and jacobian tests
195203
struct InpStruct
196204
i1::Float64

0 commit comments

Comments
 (0)