Skip to content

Commit 80499b0

Browse files
Fix reshape copy (#1323)
* Fix reshape copy * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 12ca833 commit 80499b0

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

src/TracedRArray.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,10 @@ function Base.copyto!(dest::TracedRArray{T,N}, src::TracedRArray{T2,N}) where {T
747747
return copyto!(dest, Ops.convert(TracedRArray{T,N}, src))
748748
end
749749

750+
function Base.copyto!(dest::AnyTracedRArray, src::AnyTracedRArray)
751+
return copyto!(dest, materialize_traced_array(src))
752+
end
753+
750754
function Base.copyto!(dest::TracedRArray{T,N}, src::Array{T2,N}) where {T,T2,N}
751755
return copyto!(dest, TracedUtils.promote_to(TracedRArray{T2,N}, src))
752756
end

test/basic.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,6 +1061,23 @@ end
10611061
@test Array(x_ra) == x
10621062
end
10631063

1064+
function reshapecopy!(x, y)
1065+
Base.copyto!(x, reshape(y, size(x)))
1066+
return nothing
1067+
end
1068+
@testset "copyto! Reshaped TracedRArray" begin
1069+
x = zeros(3, 4, 5)
1070+
y = collect(reshape(1:60, (3, 20)))
1071+
1072+
xr = Reactant.to_rarray(x)
1073+
yr = Reactant.to_rarray(y)
1074+
1075+
@jit reshapecopy!(xr, yr)
1076+
1077+
reshapecopy!(x, y)
1078+
@test Array(xr) == x
1079+
end
1080+
10641081
@testset "copy(::Broadcast.Broadcasted{ArrayStyle{ConcreteRArray}})" begin
10651082
x_ra = Reactant.to_rarray(ones(4, 4))
10661083
res = copy(Broadcast.broadcasted(-, Broadcast.broadcasted(+, x_ra, 1)))

0 commit comments

Comments
 (0)