Skip to content

Commit e8c7c5b

Browse files
authored
feat: add ifrt copy api (#1334)
* feat: ifrt copy api (julia bindings) * chore: bump jll
1 parent 7a121b0 commit e8c7c5b

File tree

4 files changed

+30
-9
lines changed

4 files changed

+30
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ PythonCall = "0.9"
8787
Random = "1.10"
8888
Random123 = "1.7"
8989
ReactantCore = "0.1.10"
90-
Reactant_jll = "0.0.187"
90+
Reactant_jll = "0.0.188"
9191
ScopedValues = "1.3.0"
9292
Scratch = "1.2"
9393
Sockets = "1.10"

src/ConcreteRArray.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,22 @@ function Base.copy(x::Union{AbstractConcreteArray,AbstractConcreteNumber})
2323
return fn(x)
2424
end
2525

26+
function Base.copy(X::ConcreteIFRTArray{T,D,S,P}) where {T,D,S,P}
27+
return ConcreteIFRTArray{T,D,S}(Base.copy(X.data), X.shape, X.sharding, X.padding)
28+
end
29+
30+
function Base.copy(X::ConcretePJRTArray)
31+
return Core.Typeof(X)(Base.copy.(X.data), X.shape, X.sharding)
32+
end
33+
34+
function Base.copy(X::ConcreteIFRTNumber)
35+
return Core.Typeof(X)(Base.copy(X.data), X.sharding)
36+
end
37+
38+
function Base.copy(X::ConcretePJRTNumber)
39+
return Core.Typeof(X)(Base.copy.(X.data), X.sharding)
40+
end
41+
2642
# deepcopy
2743
function Base.deepcopy(x::Union{AbstractConcreteArray,AbstractConcreteNumber})
2844
return Base.copy(x)
@@ -117,14 +133,6 @@ function write_to_host_buffer!(data::Array, X::ConcretePJRTArray{T,N}) where {T,
117133
return nothing
118134
end
119135

120-
function Base.copy(X::ConcretePJRTArray)
121-
return Core.Typeof(X)(Base.copy.(X.data), X.shape, X.sharding)
122-
end
123-
124-
function Base.copy(X::ConcretePJRTNumber)
125-
return Core.Typeof(X)(Base.copy.(X.data), X.sharding)
126-
end
127-
128136
function write_to_host_buffer!(data::Array, X::ConcreteIFRTArray{T,N}) where {T,N}
129137
XLA.to_host(X.data, data, X.sharding)
130138
return nothing

src/xla/IFRT/Array.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,3 +424,11 @@ function copy_arrays_to_device_with_sharding(buffers::Vector{Array}, sharding::S
424424
end
425425
return dst_arrays
426426
end
427+
428+
function Base.copy(b::Array)
429+
GC.@preserve b begin
430+
return Array(
431+
@ccall MLIR.API.mlir_c.ifrt_copy_array(b.buffer::Ptr{Cvoid})::Ptr{Cvoid}
432+
)
433+
end
434+
end

src/xla/IFRT/AsyncArray.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,8 @@ function XLA.to_host(array::AsyncArray, data, reactant_sharding)
2525
end
2626

2727
XLA.sharding(x::AsyncArray) = XLA.sharding(x.buffer)
28+
29+
function Base.copy(b::AsyncArray)
30+
Base.wait(b)
31+
return AsyncArray(Base.copy(b.buffer), nothing)
32+
end

0 commit comments

Comments
 (0)