3 changes: 3 additions & 0 deletions Project.toml
@@ -47,6 +47,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
FixedSizeArrays = "3821ddf9-e5b5-40d5-8e25-6813ab96b5e2"

[sources]
ReactantCore = {path = "lib/ReactantCore"}
@@ -57,6 +58,7 @@ ReactantArrayInterfaceExt = "ArrayInterface"
ReactantCUDAExt = ["CUDA", "GPUCompiler", "KernelAbstractions", "LLVM"]
ReactantDLFP8TypesExt = "DLFP8Types"
ReactantFillArraysExt = "FillArrays"
ReactantFixedSizeArraysExt = "FixedSizeArrays"
ReactantFloat8sExt = "Float8s"
ReactantKernelAbstractionsExt = "KernelAbstractions"
ReactantMPIExt = "MPI"
@@ -82,6 +84,7 @@ EnumX = "1"
Enzyme = "0.13.74"
EnzymeCore = "0.8.13"
FillArrays = "1.13"
FixedSizeArrays = "1.2.0"
Float8s = "0.1"
Functors = "0.5"
GPUArraysCore = "0.2"
36 changes: 36 additions & 0 deletions ext/ReactantFixedSizeArraysExt.jl
@@ -0,0 +1,36 @@
module ReactantFixedSizeArraysExt

using FixedSizeArrays
using Reactant
using Reactant: TracedRArray, TracedRNumber, Ops
using ReactantCore: ReactantCore
Comment on lines +3 to +6

Collaborator:

Suggested change:
-using FixedSizeArrays
-using Reactant
-using Reactant: TracedRArray, TracedRNumber, Ops
-using ReactantCore: ReactantCore
+using FixedSizeArrays: FixedSizeArrayDefault
+using Reactant: Reactant, TracedRArray, TracedRNumber, Ops
+using ReactantCore: ReactantCore

function Reactant.traced_type_inner(
@nospecialize(_::Type{FixedSizeArrays.FixedSizeArrayDefault{T,N}}),
Collaborator:

Suggested change:
-@nospecialize(_::Type{FixedSizeArrays.FixedSizeArrayDefault{T,N}}),
+@nospecialize(_::Type{FixedSizeArrayDefault{T,N}}),

seen,
@nospecialize(mode::Reactant.TraceMode),
@nospecialize(track_numbers::Type),
@nospecialize(sharding),
@nospecialize(runtime)
) where {T,N}
T2 = Reactant.TracedRNumber{T}
return FixedSizeArrays.FixedSizeArrayDefault{T2,N}
end

Base.@nospecializeinfer function Reactant.make_tracer(
seen,
@nospecialize(prev::FixedSizeArrays.FixedSizeArrayDefault{T,N}),
Collaborator:

Suggested change:
-@nospecialize(prev::FixedSizeArrays.FixedSizeArrayDefault{T,N}),
+@nospecialize(prev::FixedSizeArrayDefault{T,N}),

@nospecialize(path),
mode;
kwargs...,
) where {T,N}
shape = size(prev)
return reshape(
Member:
I guess this reshape is the culprit? How does Array work? Is it possible to construct this object directly with the right size instead of reshaping it?

Member:
but array itself we override to keep the dimensionality and allocate everything ourselves for Array.

if the actual tracing of a fixedsizearray does the current "generic recursion into structs" it will eventually allocate a 1-dim memory, always
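
(For context, a rough structural sketch of such a fixed-size wrapper; the struct and field names below are illustrative assumptions, not the real FixedSizeArrays internals. The point is that the only array-shaped field is a flat 1-D backing store, so generic field-by-field tracing only ever sees that 1-D memory.)

# Illustrative sketch only; names are assumptions, not FixedSizeArrays' actual definition.
struct ToyFixedSizeArray{T,N,Mem<:DenseVector{T}} <: AbstractArray{T,N}
    mem::Mem            # flat 1-D backing storage (a Memory{T} by default)
    size::NTuple{N,Int} # shape is fixed at construction and never changes
end

Base.size(a::ToyFixedSizeArray) = a.size
Base.getindex(a::ToyFixedSizeArray, i::Int) = a.mem[i]
Base.IndexStyle(::Type{<:ToyFixedSizeArray}) = IndexLinear()

# A generic "recurse into struct fields" tracer would visit `mem` (1-D) and the
# size tuple, so the traced result naturally comes back 1-dimensional unless
# the wrapper reshapes / reconstructs it, as the method above does.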

Member (@giordano, Sep 23, 2025):
The whole point of FixedSizeArray is that the size is...well...fixed. Having to reshape it all the time seems to go into the opposite direction, especially when Array doesn't have that.

Member (@giordano, Sep 23, 2025):
It feels to me like make_tracer should take the size as an (optional) argument. Looking at

Reactant.jl/src/Tracing.jl, lines 1177 to 1196 in e4bb34f:

Base.@nospecializeinfer function make_tracer(
seen,
@nospecialize(prev::ConcreteIFRTArray{T,N}),
@nospecialize(path),
mode;
@nospecialize(sharding = Sharding.NoSharding()),
@nospecialize(device = nothing),
@nospecialize(client = nothing),
kwargs...,
) where {T,N}
if mode == TracedToTypes
throw("Cannot have ConcreteIFRTArray as function call argument.")
end
mode == ArrayToConcrete && return ConcreteIFRTArray(prev; sharding, device, client)
mode != ConcreteToTraced && throw("Cannot trace concrete")
haskey(seen, prev) && return seen[prev]::TracedRArray{T,N}
res = TracedRArray{T,N}((path,), nothing, size(prev))
seen[prev] = res
return res
end
(and all similar methods) the size could be another argument which defaults to size(prev) but could be overridden if passed explicitly.
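
For illustration, a minimal sketch of that idea against the method quoted above (hypothetical; the keyword name `dims` and its placement are assumptions, not the change actually made in #1712):

# Hypothetical sketch of the optional-size argument; not the actual #1712 change.
Base.@nospecializeinfer function make_tracer(
    seen,
    @nospecialize(prev::ConcreteIFRTArray{T,N}),
    @nospecialize(path),
    mode;
    @nospecialize(sharding = Sharding.NoSharding()),
    @nospecialize(device = nothing),
    @nospecialize(client = nothing),
    dims=size(prev),  # new: callers (e.g. wrapper types) may override the traced shape
    kwargs...,
) where {T,N}
    if mode == TracedToTypes
        throw("Cannot have ConcreteIFRTArray as function call argument.")
    end
    mode == ArrayToConcrete && return ConcreteIFRTArray(prev; sharding, device, client)
    mode != ConcreteToTraced && throw("Cannot trace concrete")
    haskey(seen, prev) && return seen[prev]
    res = TracedRArray{T,length(dims)}((path,), nothing, dims)
    seen[prev] = res
    return res
end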

Member:
I opened #1712 to implement my suggestion.

Member:
I guess I'm missing something fundamental: Array has the same memory backend, why should FixedSizeArray be any different?

Collaborator:
Array maps to our Concrete/Traced Array. For all "wrapped" types, we need to preserve the wrapper type if we want the outputs to preserve the wrapped type. Otherwise any operation you perform on a FixedSizeArray will inevitably be mapped to an Array output.
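
To make the point concrete, a hypothetical illustration (the exact returned types are an assumption, not captured output from this PR):

# Hedged illustration of wrapper preservation; depends on this PR's extension.
using Reactant, FixedSizeArrays

x  = FixedSizeArray(fill(1.0f0, 4))
rx = Reactant.to_rarray(x)

y = @jit identity(rx)
# Without a wrapper-aware make_tracer / traced_type_inner, `y` comes back as a
# plain concrete Reactant array and the FixedSizeArray wrapper is dropped; the
# overrides in this extension exist so the wrapper type can be preserved.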

Member:
I'm not really convinced a proposal for changing the memory backend in FixedSizeArrays is going to fly: it's meant to follow Array very closely, and the memory backend is always a flattened dense vector.

Member:
I mean, once #1696 goes in, I think we should be able to make that work (though the backing memory might need to be a reshape(tracedrarray{n}, 1)), but the reshape will optimize out with the one that's emitted currently.

Collaborator:
ReshapedArray unfortunately is an AbstractArray and not a DenseArray.

Collaborator:
Missing the call to FixedSizeArray.

Author (@Qfl3x, Sep 30, 2025):
Returning a FixedSizeArray causes problems.

With track_numbers=Number, running @jit(f(rx)) I get this:

ERROR: TypeError: in RNumber, in T, expected T<:Union{Complex{Int16}, Complex{Int32}, Complex{Int64}, Complex{Int8}, Complex{UInt16}, Complex{UInt32}, Complex{UInt64}, Complex{UInt8}, Core.BFloat16, Bool, Float16, Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, Reactant.F8E4M3B11FNUZ, Reactant.F8E4M3FN, Reactant.F8E4M3FNUZ, Reactant.F8E5M2, Reactant.F8E5M2FNUZ, ComplexF64, ComplexF32}, got Type{Reactant.TracedRNumber{Float32}}
Stacktrace:
  [1] copyto!(dest::Memory{Reactant.TracedRNumber{Float32}}, src::Vector{Reactant.TracedRNumber{Float32}})
    @ Reactant.TracedRArrayOverrides ~/projects/Reactant.jl/src/TracedRArray.jl:492
  [2] collect_as_vectorlike_with_known_eltype_and_length(::Type{Memory{Reactant.TracedRNumber{Float32}}}, collection::Vector{Reactant.TracedRNumber{Float32}})
    @ Collects ~/.julia/packages/Collects/iMuH5/src/Collects.jl:269
  [3] collect_as_memory_with_known_eltype_and_known_length(::Type{Reactant.TracedRNumber{Float32}}, collection::Vector{Reactant.TracedRNumber{Float32}})
    @ Collects ~/.julia/packages/Collects/iMuH5/src/Collects.jl:345
  [4] collect_as_memory_with_known_eltype(::Type{Reactant.TracedRNumber{Float32}}, collection::Vector{Reactant.TracedRNumber{Float32}})
    @ Collects ~/.julia/packages/Collects/iMuH5/src/Collects.jl:353
  [5] collect_as_memory(::Function, ::Type{Memory{Reactant.TracedRNumber{Float32}}}, collection::Vector{Reactant.TracedRNumber{Float32}})
    @ Collects ~/.julia/packages/Collects/iMuH5/src/Collects.jl:373
  [6] collect_as_common_invariant(e::typeof(Collects.EmptyIteratorHandling.just_throws), ::Type{Memory{Reactant.TracedRNumber{Float32}}}, collection::Vector{Reactant.TracedRNumber{Float32}})
    @ Collects ~/.julia/packages/Collects/iMuH5/src/Collects.jl:376
  [7] collect_as_common(e::typeof(Collects.EmptyIteratorHandling.just_throws), type::Type{Memory{Reactant.TracedRNumber{Float32}}}, collection::Vector{Reactant.TracedRNumber{Float32}})
    @ Collects ~/.julia/packages/Collects/iMuH5/src/Collects.jl:400
  [8] (::Collects.Collect{typeof(Collects.EmptyIteratorHandling.just_throws)})(type::Type{Memory{Reactant.TracedRNumber{Float32}}}, collection::Vector{Reactant.TracedRNumber{Float32}})
    @ Collects ~/.julia/packages/Collects/iMuH5/src/Collects.jl:411
  [9] (::Collects.Collect{typeof(Collects.EmptyIteratorHandling.just_throws)})(::Type{FixedSizeArrayDefault{Reactant.TracedRNumber{Float32}}}, iterator::Vector{Reactant.TracedRNumber{Float32}})
    @ FixedSizeArrays ~/.julia/packages/FixedSizeArrays/VHLXx/src/collect_as.jl:45
 [10] collect_as(::Type{FixedSizeArrayDefault{Reactant.TracedRNumber{Float32}}}, collection::Vector{Reactant.TracedRNumber{Float32}}; empty_iterator_handler::typeof(Collects.EmptyIteratorHandling.just_throws))
    @ Collects ~/.julia/packages/Collects/iMuH5/src/Collects.jl:425
 [11] collect_as(::Type{FixedSizeArrayDefault{Reactant.TracedRNumber{Float32}}}, collection::Vector{Reactant.TracedRNumber{Float32}})
    @ Collects ~/.julia/packages/Collects/iMuH5/src/Collects.jl:423
 [12] collect_as_haseltype(::Type{FixedSizeArrayDefault}, iterator::Vector{Reactant.TracedRNumber{Float32}})
    @ FixedSizeArrays ~/.julia/packages/FixedSizeArrays/VHLXx/src/collect_as.jl:56
 [13] converting_constructor(::Type{FixedSizeArrayDefault}, src::Vector{Reactant.TracedRNumber{Float32}})
    @ FixedSizeArrays ~/.julia/packages/FixedSizeArrays/VHLXx/src/FixedSizeArray.jl:334
 [14] (FixedSizeArrayDefault)(src::Vector{Reactant.TracedRNumber{Float32}})
    @ FixedSizeArrays ~/.julia/packages/FixedSizeArrays/VHLXx/src/FixedSizeArray.jl:345
 [15] make_tracer(seen::Reactant.OrderedIdDict{Any, Any}, prev::FixedSizeArray{Float32, 1, Memory{Float32}}, path::Any, mode::Reactant.TraceMode; kwargs::@Kwargs{runtime::Val{:PJRT}})
    @ ReactantFixedSizeArraysExt ~/projects/Reactant.jl/ext/ReactantFixedSizeArraysExt.jl:33

I don't know why it's throwing this, as after inspection the inner Memory object is traced properly.

Without it I get:

ERROR: cannot copy Ptr{Nothing} @0x0000795f973dc1e0 of type Ptr{Nothing}
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] create_result(tocopy::Ptr{…}, path::Tuple{…}, result_stores::Dict{…}, path_to_shard_info::Nothing, to_unreshard_results::Dict{…}, unresharded_code::Vector{…}, unresharded_arrays_cache::Dict{…}, used_shardinfo::Set{…}, result_cache::IdDict{…}, var_idx::Base.RefValue{…}, resultgen_code::Vector{…})
   @ Reactant.Compiler ~/projects/Reactant.jl/src/Compiler.jl:231
 [3] create_result(tocopy::Memory{…}, path::Tuple{…}, result_stores::Dict{…}, path_to_shard_info::Nothing, to_unreshard_results::Dict{…}, unresharded_code::Vector{…}, unresharded_arrays_cache::Dict{…}, used_shardinfo::Set{…}, result_cache::IdDict{…}, var_idx::Base.RefValue{…}, resultgen_code::Vector{…})
   @ Reactant.Compiler ~/projects/Reactant.jl/src/Compiler.jl:256
 [4] create_result(tocopy::FixedSizeArray{…}, path::Tuple{}, result_stores::Dict{…}, path_to_shard_info::Nothing, to_unreshard_results::Dict{…}, unresharded_code::Vector{…}, unresharded_arrays_cache::Dict{…}, used_shardinfo::Set{…}, result_cache::IdDict{…}, var_idx::Base.RefValue{…}, resultgen_code::Vector{…})
   @ Reactant.Compiler ~/projects/Reactant.jl/src/Compiler.jl:256
 [5] codegen_unflatten!(linear_args::Vector{…}, preserved_args::Vector{…}, concretized_res_names::Vector{…}, linear_results::Vector{…}, concrete_result::FixedSizeArray{…}, result_stores::Dict{…}, path_to_shard_info::Nothing, linear_result_shard_info::Vector{…}, client::Reactant.XLA.PJRT.Client, resharded_inputs::Dict{…})
   @ Reactant.Compiler ~/projects/Reactant.jl/src/Compiler.jl:3147
 [6] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
   @ Reactant.Compiler ~/projects/Reactant.jl/src/Compiler.jl:3599
 [7] top-level scope
   @ ~/projects/Reactant.jl/src/Compiler.jl:2614

For this one I don't have a clue.

Author:
So:

mem = Memory{Float32}([1.0f0, 2.0f0])
getfield(mem, 2)  # returns the Ptr{Nothing} <pointer address> backing the Memory

This is where the pointer is coming from.

Author:
Since Reactant doesn't support resizing for normal arrays, and FixedSizeArray isn't doing any fancy memory/optimization stuff other than the fixed size (unlike OneHotArrays or FillArrays), shouldn't it be fine for the FixedSizeArray to transform into an ordinary Concrete array?

Reactant.make_tracer(
seen, parent(prev), (path..., 1), mode; kwargs..., track_numbers=Number
),
shape,
)
end

end
91 changes: 91 additions & 0 deletions src/Tracing.jl
@@ -1812,6 +1812,97 @@ Base.@nospecializeinfer function make_tracer(
return res
end

if isdefined(Base, :Memory)
Base.@nospecializeinfer function make_tracer(
seen,
@nospecialize(prev::Memory),
@nospecialize(path),
mode;
@nospecialize(track_numbers::Type = Union{}),
@nospecialize(sharding = Sharding.NoSharding()),
@nospecialize(runtime = nothing),
@nospecialize(device = nothing),
@nospecialize(client = nothing),
kwargs...,
)
RT = Core.Typeof(prev)
# XXX: If someone wants to shard the same array with different shardings, we need to
# somehow handle this correctly... Right now we just use the first sharding.
if mode != NoStopTracedTrack && haskey(seen, prev)
if mode == TracedToTypes
visited = seen[prev]
push!(path, visited)
return nothing
end
return seen[prev]
end
if eltype(RT) <: ReactantPrimitive
if mode == ArrayToConcrete
runtime isa Val{:PJRT} &&
(return seen[prev] = ConcretePJRTArray(prev; sharding, device, client))
runtime isa Val{:IFRT} &&
(return seen[prev] = ConcreteIFRTArray(prev; sharding, device, client))
error("Unsupported runtime $runtime")
elseif mode == TracedToTypes
# Original array can get mutated so we store a copy:
push!(path, copy(prev))
seen[prev] = VisitedObject(length(seen) + 1)
return nothing
end
elseif mode == TracedToTypes
push!(path, RT)
for I in eachindex(prev)
if isassigned(prev, I)
pv = prev[I]
make_tracer(
seen,
pv,
path,
mode;
track_numbers,
sharding,
runtime,
device,
client,
kwargs...,
)
end
end
return nothing
end
TT = traced_type(eltype(RT), Val(mode), track_numbers, sharding, runtime)
newa = Array{TT,ndims(RT)}(undef, size(prev))
seen[prev] = newa
same = true
for I in eachindex(prev)
if isassigned(prev, I)
pv = prev[I]
nv = make_tracer(
seen,
pv,
append_path(path, I),
mode;
track_numbers,
sharding=Base.getproperty(sharding, I),
runtime,
device,
client,
kwargs...,
)
if pv !== nv
same = false
end
@inbounds newa[I] = nv
end
end
if same
seen[prev] = prev
return prev
end
return newa
end
end

Base.@nospecializeinfer function make_tracer(
seen,
@nospecialize(prev::Sharding.Mesh),
40 changes: 38 additions & 2 deletions src/Types.jl
@@ -228,6 +228,24 @@ function ConcretePJRTArray(
return ConcretePJRTArray{T,N,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo)
end

if isdefined(Base, :Memory)
function ConcretePJRTArray(
data::Memory{T};
client::Union{Nothing,XLA.PJRT.Client}=nothing,
idx::Union{Int,Nothing}=nothing,
device::Union{Nothing,XLA.PJRT.Device}=nothing,
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
) where {T}
theclient, thedevice = _select_client_and_device(client, idx, device, sharding)
sharded_data, shardinfo = sharding(theclient, thedevice, data)
shape = size(data)
nsharded = length(sharded_data)
return ConcretePJRTArray{T,1,nsharded,typeof(shardinfo)}(
sharded_data, shape, shardinfo
)
end
end

Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data)
XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data)
function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber})
@@ -356,6 +374,21 @@ function ConcreteIFRTArray(
return ConcreteIFRTArray{T,N,typeof(shardinfo)}(sharded_data, shape, shardinfo, padding)
end

if isdefined(Base, :Memory)
function ConcreteIFRTArray(
data::Memory{T};
client::Union{Nothing,XLA.IFRT.Client}=nothing,
idx::Union{Int,Nothing}=nothing,
device::Union{Nothing,XLA.IFRT.Device}=nothing,
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
) where {T}
theclient, thedevice = _select_client_and_device(client, idx, device, sharding)
sharded_data, shardinfo, padding = sharding(theclient, nothing, data)
shape = size(data)
return ConcreteIFRTArray{T,1,typeof(shardinfo)}(sharded_data, shape, shardinfo)
end
end

# Assemble data from multiple arrays. Needed in distributed setting where each process wont
# have enough host memory to hold all the arrays. We assume that the data is only provided
# for all of the addressable devices.
@@ -472,8 +505,11 @@ elseif XLA.REACTANT_XLA_RUNTIME == "IFRT"
ConcreteIFRTArray
end

@inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) where {T} =
ConcreteRArray{T}(undef, Dims(shape); kwargs...)
@inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) where {T} = ConcreteRArray{
T
}(
undef, Dims(shape); kwargs...
)
Comment on lines +519 to +523
Member:
https://github.com/EnzymeAD/Reactant.jl/actions/runs/17949332627/job/51048248283?pr=1669#step:2:570

Suggested change:
-@inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) where {T} = ConcreteRArray{
-T
-}(
-undef, Dims(shape); kwargs...
-)
+@inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) where {T} =
+ConcreteRArray{T}(undef, Dims(shape); kwargs...)


"""
ConcreteRNumber(
82 changes: 82 additions & 0 deletions src/xla/IFRT/Array.jl
@@ -38,6 +38,30 @@ function Array(
return Array(buffer)
end

if isdefined(Base, :Memory)
function Array(
client::Client,
memory::Base.Memory{T},
device::Device=XLA.default_device(client),
memory_kind::AbstractString=string(convert(MemoryKind, XLA.default_memory(device))),
) where {T<:Reactant.ReactantPrimitive}
sizear = collect(Int64, reverse(size(memory)))
buffer = GC.@preserve memory sizear begin
@ccall MLIR.API.mlir_c.ifrt_client_make_single_shard_array_from_host_buffer(
client.client::Ptr{Cvoid},
pointer(memory)::Ptr{T},
XLA.primitive_type(T)::UInt64,
1::Csize_t,
sizear::Ptr{Int64},
0::Cint, # kAlwaysCopy
device.device::Ptr{Cvoid},
string(memory_kind)::Cstring,
)::Ptr{Cvoid}
end
return Array(buffer)
end
end

function Array(
client::Client, array::Base.Array{T,N}, sharding::Sharding
) where {T<:Reactant.ReactantPrimitive,N}
@@ -143,6 +167,47 @@ function Array(
return Array(buffer)
end

if isdefined(Base, :Memory)
function Array(
client::Client, memory::Base.Memory{T}, sharding::Sharding
) where {T<:Reactant.ReactantPrimitive}
all_devices = XLA.devices(sharding)
all_logical_device_ids = collect(Int64, 0:(length(all_devices) - 1))
hlo_sharding = convert(XLA.HloSharding, sharding)

slices, _ = XLA.sharding_to_concrete_array_indices(
hlo_sharding, size(memory), all_logical_device_ids
)

seen_slice = Dict{NTuple{N,UnitRange{Int64}},Int}()
host_buffers = Base.Array{T,1}[]
addressable_shard_indices = Vector{Int64}[]

cur_shard = 0
for (slice, device) in zip(slices, all_devices)
XLA.is_addressable(device) || continue

if haskey(seen_slice, slice)
idx = seen_slice[slice]
push!(addressable_shard_indices[idx], cur_shard)
else
host_buffer = let slice = memory[slice...]
slice isa Number ? collect(slice) : slice
end
push!(host_buffers, host_buffer)
push!(addressable_shard_indices, Int64[cur_shard])
seen_slice[slice] = length(host_buffers)
end

cur_shard += 1
end

return Array(
client, host_buffers, addressable_shard_indices, size(memory), sharding
)
end
end

function Array(
client::Client, array::Base.Array{T,N}, sharding
) where {T<:Reactant.ReactantPrimitive,N}
@@ -158,6 +223,23 @@ function Array(
return Array(client, array, ifrt_sharding)
end

if isdefined(Base, :Memory)
function Array(
client::Client, memory::Base.Memory{T}, sharding
) where {T<:Reactant.ReactantPrimitive}
@assert sharding isa Reactant.Sharding.AbstractSharding
if !(sharding isa Reactant.Sharding.HloSharding)
sharding = Reactant.Sharding.HloSharding(sharding, size(memory))
end

(; hlo_sharding, mesh) = sharding
devices = XLA.get_device.((client,), mesh.device_ids)
ifrt_sharding = Sharding([devices...], hlo_sharding)

return Array(client, memory, ifrt_sharding)
end
end

@inline function XLA.free_buffer(buffer::Array)
if buffer.buffer != C_NULL
@ccall MLIR.API.mlir_c.ifrt_free_array(buffer.buffer::Ptr{Cvoid})::Cvoid
17 changes: 17 additions & 0 deletions src/xla/PJRT/Buffer.jl
@@ -21,6 +21,23 @@ function Buffer(client::Client, array::Array{T,N}, device::Device) where {T,N}
return Buffer(buffer)
end

if isdefined(Base, :Memory)
function Buffer(client::Client, memory::Memory{T}, device::Device) where {T}
sizear = collect(Int64, reverse(size(memory)))
buffer = GC.@preserve memory sizear begin
@ccall MLIR.API.mlir_c.ArrayFromHostBuffer(
client.client::Ptr{Cvoid},
pointer(memory)::Ptr{T},
XLA.primitive_type(T)::UInt64,
1::Csize_t,
pointer(sizear)::Ptr{Int64},
device.device::Ptr{Cvoid},
)::Ptr{Cvoid}
end
return Buffer(buffer)
end
end

function Base.similar(a::Buffer)
buffer = GC.@preserve a begin
@ccall MLIR.API.mlir_c.UninitPJRTBuffer(
1 change: 1 addition & 0 deletions test/Project.toml
@@ -9,6 +9,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FixedSizeArrays = "3821ddf9-e5b5-40d5-8e25-6813ab96b5e2"
Float8s = "81dfefd7-55b0-40c6-a251-db853704e186"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
17 changes: 17 additions & 0 deletions test/integration/fixedsizearrays.jl
@@ -0,0 +1,17 @@

using Reactant, Test, FixedSizeArrays

fn(x, y) = (2 .* x .- 3) * y'

@testset "FixedSizeArrays" begin
@testset "1D" begin
x = FixedSizeArray(fill(3.0f0, 100))
rx = Reactant.to_rarray(x)
@test @jit(fn(rx, rx)) ≈ fn(x, x)
end
@testset "2D" begin
x = FixedSizeArray(fill(3.0f0, (4, 5)))
rx = Reactant.to_rarray(x)
@test @jit(fn(rx, rx)) ≈ fn(x, x)
end
end
10 changes: 10 additions & 0 deletions test/memory.jl
@@ -0,0 +1,10 @@
using Reactant, Test

fn(x, y) = sin.(x) .+ cos.(y)

@testset "Memory test" begin
x = Memory{Float32}(fill(2.0f0, 10))
x_ra = Reactant.to_rarray(x)

@test @jit(fn(x_ra, x_ra)) ≈ fn(x, x)
end