- 
                Notifications
    You must be signed in to change notification settings 
- Fork 33
FixedSizeArrays support (And Memory Support) #1669
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
e0906bc
              460dd4c
              73205e2
              15eca0b
              cbd5e75
              702ac1d
              658a3d7
              7e79d15
              08a425e
              8b9f24d
              dc00c6a
              ff2d39b
              995c8a3
              8ef1575
              c1842c3
              a033cc6
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,36 @@ | ||||||||||||||||||||||||||||||||||||||||||
| module ReactantFixedSizeArraysExt | ||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||
| using FixedSizeArrays | ||||||||||||||||||||||||||||||||||||||||||
| using Reactant | ||||||||||||||||||||||||||||||||||||||||||
| using Reactant: TracedRArray, TracedRNumber, Ops | ||||||||||||||||||||||||||||||||||||||||||
| using ReactantCore: ReactantCore | ||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||
| function Reactant.traced_type_inner( | ||||||||||||||||||||||||||||||||||||||||||
| @nospecialize(_::Type{FixedSizeArrays.FixedSizeArrayDefault{T,N}}), | ||||||||||||||||||||||||||||||||||||||||||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 
        Suggested change
       
 | ||||||||||||||||||||||||||||||||||||||||||
| seen, | ||||||||||||||||||||||||||||||||||||||||||
| @nospecialize(mode::Reactant.TraceMode), | ||||||||||||||||||||||||||||||||||||||||||
| @nospecialize(track_numbers::Type), | ||||||||||||||||||||||||||||||||||||||||||
| @nospecialize(sharding), | ||||||||||||||||||||||||||||||||||||||||||
| @nospecialize(runtime) | ||||||||||||||||||||||||||||||||||||||||||
| ) where {T,N} | ||||||||||||||||||||||||||||||||||||||||||
| T2 = Reactant.TracedRNumber{T} | ||||||||||||||||||||||||||||||||||||||||||
| return FixedSizeArrays.FixedSizeArrayDefault{T2,N} | ||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||
| Base.@nospecializeinfer function Reactant.make_tracer( | ||||||||||||||||||||||||||||||||||||||||||
| seen, | ||||||||||||||||||||||||||||||||||||||||||
| @nospecialize(prev::FixedSizeArrays.FixedSizeArrayDefault{T,N}), | ||||||||||||||||||||||||||||||||||||||||||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 
        Suggested change
       
 | ||||||||||||||||||||||||||||||||||||||||||
| @nospecialize(path), | ||||||||||||||||||||||||||||||||||||||||||
| mode; | ||||||||||||||||||||||||||||||||||||||||||
| kwargs..., | ||||||||||||||||||||||||||||||||||||||||||
| ) where {T,N} | ||||||||||||||||||||||||||||||||||||||||||
| shape = size(prev) | ||||||||||||||||||||||||||||||||||||||||||
| return reshape( | ||||||||||||||||||||||||||||||||||||||||||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess this  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. but array itself we override to keep the dimensionality and allocate everything ourselves for Array. if the actual tracing of a fixedsizearray does the current "generic recursion into structs" it will eventually allocate a 1-dim memory, always There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The whole point of  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It feels to me like  Lines 1177 to 1196 in e4bb34f 
 size(prev)but could be overridden if passed explicitly.There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I opened #1712 to implement my suggestion. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess I'm missing something fundamental:  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Array maps to our Concrete/Traced Array. For all "wrapped" types, we need to preserve the wrapper type, if we want the outputs to preserve the wrapped type. Else any operation you perform on a FixedSizeArray with inevitably be mapped to an Array output There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not really convinced a proposal for changing the memory backend in FixedSizeArrays is going to fly: it's meant to follow Array very closely, and the memory backend is always a flattened dense vector. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean once #1696 goes in, I think we should be able to make that work (though the backing memory might need to be a reshape(tracedrarray{n}, 1), but the reshape will optimize out with the one that's emitted currently There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ReshapedArray unfortunately is a AbstractArray and not a DenseArray There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing the call to FixedSizedArray There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Returning a FixedSizeArray causes problems. With  I don't know why it's throwing this, as after inspection the inner  Without it I get: For this one I don't have a clue. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So: mem = Memory{Float32}([1.0f0, 2.0f0])
getfield(mem, 2)This is where the pointer is coming from. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since Reactant isn't supporting resizing for normal arrays, and FixedSizeArray isn't doing any fancy memory/optimization stuff other than fixed size (unlike OneHotArrays, FillArrays) shouldn't it be fine for the FixedSizeArray to transform into an ordinary Concrete Array? | ||||||||||||||||||||||||||||||||||||||||||
| Reactant.make_tracer( | ||||||||||||||||||||||||||||||||||||||||||
| seen, parent(prev), (path..., 1), mode; kwargs..., track_numbers=Number | ||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||
| shape, | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|  | @@ -214,8 +214,8 @@ function ConcretePJRTArray{T,N}( | |||||||||||||||
| return ConcretePJRTArray{T,N,D,typeof(sharding)}(data, shape, sharding) | ||||||||||||||||
| end | ||||||||||||||||
|  | ||||||||||||||||
| function ConcretePJRTArray( | ||||||||||||||||
| data::Array{T,N}; | ||||||||||||||||
| function make_concrete_PJRT_array( | ||||||||||||||||
| data::AbstractArray{T,N}, | ||||||||||||||||
| client::Union{Nothing,XLA.PJRT.Client}=nothing, | ||||||||||||||||
| idx::Union{Int,Nothing}=nothing, | ||||||||||||||||
| device::Union{Nothing,XLA.PJRT.Device}=nothing, | ||||||||||||||||
|  | @@ -228,6 +228,28 @@ function ConcretePJRTArray( | |||||||||||||||
| return ConcretePJRTArray{T,N,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo) | ||||||||||||||||
| end | ||||||||||||||||
|  | ||||||||||||||||
| function ConcretePJRTArray( | ||||||||||||||||
| data::Array{T,N}; | ||||||||||||||||
| client::Union{Nothing,XLA.PJRT.Client}=nothing, | ||||||||||||||||
| idx::Union{Int,Nothing}=nothing, | ||||||||||||||||
| device::Union{Nothing,XLA.PJRT.Device}=nothing, | ||||||||||||||||
| sharding::Sharding.AbstractSharding=Sharding.NoSharding(), | ||||||||||||||||
| ) where {T,N} | ||||||||||||||||
| return make_concrete_PJRT_array(data, client, idx, device, sharding) | ||||||||||||||||
| end | ||||||||||||||||
|  | ||||||||||||||||
| if isdefined(Base, :Memory) | ||||||||||||||||
| function ConcretePJRTArray( | ||||||||||||||||
| data::Memory{T}; | ||||||||||||||||
| client::Union{Nothing,XLA.PJRT.Client}=nothing, | ||||||||||||||||
| idx::Union{Int,Nothing}=nothing, | ||||||||||||||||
| device::Union{Nothing,XLA.PJRT.Device}=nothing, | ||||||||||||||||
| sharding::Sharding.AbstractSharding=Sharding.NoSharding(), | ||||||||||||||||
| ) where {T} | ||||||||||||||||
| return make_concrete_PJRT_array(data, client, idx, device, sharding) | ||||||||||||||||
| end | ||||||||||||||||
| end | ||||||||||||||||
|  | ||||||||||||||||
| Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data) | ||||||||||||||||
| XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data) | ||||||||||||||||
| function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) | ||||||||||||||||
|  | @@ -342,8 +364,8 @@ function ConcreteIFRTArray{T,N}( | |||||||||||||||
| return ConcreteIFRTArray{T,N,typeof(sharding)}(data, shape, sharding) | ||||||||||||||||
| end | ||||||||||||||||
|  | ||||||||||||||||
| function ConcreteIFRTArray( | ||||||||||||||||
| data::Array{T,N}; | ||||||||||||||||
| function make_concrete_IFRT_array( | ||||||||||||||||
| data::AbstractArray{T,N}, | ||||||||||||||||
| client::Union{Nothing,XLA.IFRT.Client}=nothing, | ||||||||||||||||
| idx::Union{Int,Nothing}=nothing, | ||||||||||||||||
| device::Union{Nothing,XLA.IFRT.Device}=nothing, | ||||||||||||||||
|  | @@ -356,6 +378,28 @@ function ConcreteIFRTArray( | |||||||||||||||
| return ConcreteIFRTArray{T,N,typeof(shardinfo)}(sharded_data, shape, shardinfo, padding) | ||||||||||||||||
| end | ||||||||||||||||
|  | ||||||||||||||||
| function ConcreteIFRTArray( | ||||||||||||||||
| data::Array{T,N}; | ||||||||||||||||
| client::Union{Nothing,XLA.IFRT.Client}=nothing, | ||||||||||||||||
| idx::Union{Int,Nothing}=nothing, | ||||||||||||||||
| device::Union{Nothing,XLA.IFRT.Device}=nothing, | ||||||||||||||||
| sharding::Sharding.AbstractSharding=Sharding.NoSharding(), | ||||||||||||||||
| ) where {T,N} | ||||||||||||||||
| return make_concrete_IFRT_array(data, client, idx, device, sharding) | ||||||||||||||||
| end | ||||||||||||||||
|  | ||||||||||||||||
| if isdefined(Base, :Memory) | ||||||||||||||||
| function ConcreteIFRTArray( | ||||||||||||||||
| data::Memory{T}; | ||||||||||||||||
| client::Union{Nothing,XLA.IFRT.Client}=nothing, | ||||||||||||||||
| idx::Union{Int,Nothing}=nothing, | ||||||||||||||||
| device::Union{Nothing,XLA.IFRT.Device}=nothing, | ||||||||||||||||
| sharding::Sharding.AbstractSharding=Sharding.NoSharding(), | ||||||||||||||||
| ) where {T} | ||||||||||||||||
| return make_concrete_IFRT_array(data, client, idx, device, sharding) | ||||||||||||||||
| end | ||||||||||||||||
| end | ||||||||||||||||
|  | ||||||||||||||||
| # Assemble data from multiple arrays. Needed in distributed setting where each process wont | ||||||||||||||||
| # have enough host memory to hold all the arrays. We assume that the data is only provided | ||||||||||||||||
| # for all of the addressable devices. | ||||||||||||||||
|  | @@ -472,8 +516,11 @@ elseif XLA.REACTANT_XLA_RUNTIME == "IFRT" | |||||||||||||||
| ConcreteIFRTArray | ||||||||||||||||
| end | ||||||||||||||||
|  | ||||||||||||||||
| @inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) where {T} = | ||||||||||||||||
| ConcreteRArray{T}(undef, Dims(shape); kwargs...) | ||||||||||||||||
| @inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) where {T} = ConcreteRArray{ | ||||||||||||||||
| T | ||||||||||||||||
| }( | ||||||||||||||||
| undef, Dims(shape); kwargs... | ||||||||||||||||
| ) | ||||||||||||||||
| 
      Comment on lines
    
      +519
     to 
      +523
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. https://github.com/EnzymeAD/Reactant.jl/actions/runs/17949332627/job/51048248283?pr=1669#step:2:570 
        Suggested change
       
 | ||||||||||||||||
|  | ||||||||||||||||
| """ | ||||||||||||||||
| ConcreteRNumber( | ||||||||||||||||
|  | ||||||||||||||||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
|  | ||
| using Reactant, Test, FixedSizeArrays | ||
|  | ||
| fn(x, y) = (2 .* x .- 3) * y' | ||
|  | ||
| @testset "FixedSizeArrays" begin | ||
| @testset "1D" begin | ||
| x = FixedSizeArray(fill(3.0f0, 100)) | ||
| rx = Reactant.to_rarray(x) | ||
| @test @jit(fn(rx, rx)) ≈ fn(x, x) | ||
| end | ||
| @testset "2D" begin | ||
| x = FixedSizeArray(fill(3.0f0, (4, 5))) | ||
| rx = Reactant.to_rarray(x) | ||
| @test @jit(fn(rx, rx)) ≈ fn(x, x) | ||
| end | ||
| end | 
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| using Reactant, Test | ||
|  | ||
| fn(x, y) = sin.(x) .+ cos.(y) | ||
|  | ||
| @testset "Memory test" begin | ||
| x = Memory{Float32}(fill(2.0f0, 10)) | ||
| x_ra = Reactant.to_rarray(x) | ||
|  | ||
| @test @jit(fn(x_ra, x_ra)) ≈ fn(x, x) | ||
| end | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.