Skip to content

Commit 02d0455

Browse files
Fast pjrt copy (#1327)
* Fast pjrt copy * number * actually copy * Update src/xla/PJRT/AsyncBuffer.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent eabfc37 commit 02d0455

File tree

3 files changed

+25
-2
lines changed

3 files changed

+25
-2
lines changed

src/ConcreteRArray.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ end
2525

2626
# deepcopy
2727
function Base.deepcopy(x::Union{AbstractConcreteArray,AbstractConcreteNumber})
28-
fn = Reactant.compile(copy, (x,))
29-
return fn(x)
28+
Base.copy(x)
3029
end
3130

3231
# One more reason why users shouldn't call `deepcopy`
@@ -118,6 +117,14 @@ function write_to_host_buffer!(data::Array, X::ConcretePJRTArray{T,N}) where {T,
118117
return nothing
119118
end
120119

120+
function Base.copy(X::ConcretePJRTArray)
121+
Core.Typeof(X)(Base.copy.(X.data), X.shape, X.sharding)
122+
end
123+
124+
function Base.copy(X::ConcretePJRTNumber)
125+
Core.Typeof(X)(Base.copy.(X.data), X.sharding)
126+
end
127+
121128
function write_to_host_buffer!(data::Array, X::ConcreteIFRTArray{T,N}) where {T,N}
122129
XLA.to_host(X.data, data, X.sharding)
123130
return nothing

src/xla/PJRT/AsyncBuffer.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,8 @@ end
66
const AsyncEmptyBuffer = AsyncBuffer(Buffer(C_NULL), nothing)
77

88
AsyncBuffer(args...; kwargs...) = AsyncBuffer(Buffer(args...; kwargs...), nothing)
9+
10+
function Base.copy(b::AsyncBuffer)
11+
Base.wait(b)
12+
return AsyncBuffer(Base.copy(b.buffer), nothing)
13+
end

src/xla/PJRT/Buffer.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,17 @@ function XLA.unsafe_buffer_pointer(buffer::Buffer)
8787
@ccall MLIR.API.mlir_c.UnsafeBufferPointer(buffer.buffer::Ptr{Cvoid})::Ptr{Cvoid}
8888
end
8989

90+
function Base.copy(buffer::Buffer)
91+
dev = XLA.device(buffer)
92+
GC.@preserve buffer dev begin
93+
Buffer(
94+
@ccall MLIR.API.mlir_c.CopyBufferToDevice(
95+
buffer.buffer::Ptr{Cvoid}, dev.device::Ptr{Cvoid}
96+
)::Ptr{Cvoid}
97+
)
98+
end
99+
end
100+
90101
function XLA.copy_buffer_to_device(buffer::Buffer, dev::Device)
91102
XLA.device(buffer) == dev && return buffer
92103
GC.@preserve buffer dev begin

0 commit comments

Comments
 (0)