Skip to content

Commit 4a546e9

Browse files
add copyto of subarraay (#1405)
* add copyto of subarraay * fix * fix * fix * Update src/TracedRArray.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 6255b2e commit 4a546e9

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

src/TracedRArray.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -726,15 +726,11 @@ end
726726

727727
Base.copyto!(dest::AnyTracedRArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict
728728

729-
function Base.copyto!(dest::TracedRArray{T,N}, src::TracedRArray{T,N}) where {T,N}
730-
dest.mlir_data = src.mlir_data
729+
function Base.copyto!(dest::AnyTracedRArray{T,N}, src::TracedRArray{T,N}) where {T,N}
730+
TracedUtils.set_mlir_data!(dest, src.mlir_data)
731731
return dest
732732
end
733733

734-
function Base.copyto!(dest::TracedRArray, src::AnyTracedRArray)
735-
return copyto!(dest, materialize_traced_array(src))
736-
end
737-
738734
function Base.copyto!(
739735
dest::Reactant.TracedRArray{T},
740736
dstart::Integer,
@@ -747,14 +743,22 @@ function Base.copyto!(
747743
end
748744

749745
function Base.copyto!(dest::TracedRArray{T,N}, src::TracedRArray{T2,N}) where {T,T2,N}
750-
return copyto!(dest, Ops.convert(TracedRArray{T,N}, src))
746+
src2 = if T != T2
747+
Ops.convert(TracedRArray{T,N}, src)
748+
else
749+
src
750+
end
751+
TracedUtils.set_mlir_data!(dest, src2.mlir_data)
752+
return dest
751753
end
752754

753-
function Base.copyto!(dest::AnyTracedRArray, src::AnyTracedRArray)
755+
function Base.copyto!(
756+
dest::AnyTracedRArray{T1,N} where {T1}, src::AnyTracedRArray{T2,N} where {T2}
757+
) where {N}
754758
return copyto!(dest, materialize_traced_array(src))
755759
end
756760

757-
function Base.copyto!(dest::TracedRArray{T,N}, src::Array{T2,N}) where {T,T2,N}
761+
function Base.copyto!(dest::AnyTracedRArray{T,N}, src::Array{T2,N}) where {T,T2,N}
758762
return copyto!(dest, TracedUtils.promote_to(TracedRArray{T2,N}, src))
759763
end
760764

0 commit comments

Comments
 (0)