Skip to content

Commit fbcb5bb

Browse files
committed
feat: support statistics
1 parent 02b608a commit fbcb5bb

File tree

3 files changed

+35
-25
lines changed

3 files changed

+35
-25
lines changed

ext/ReactantStatisticsExt.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
module ReactantStatisticsExt
22

3-
using Reactant: TracedRArray
3+
using Reactant: AnyTracedRArray, materialize_traced_array
44
using Statistics: Statistics
55

6-
function Statistics.mean(A::TracedRArray{T,N}; dims=:) where {T,N}
6+
function Statistics.mean(A::AnyTracedRArray{T,N}; dims=:) where {T,N}
7+
A = materialize_traced_array(A)
78
denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)
89
return mapreduce(identity, +, A; dims) / denom
910
end
1011

1112
function Statistics.var(
12-
A::TracedRArray{T,N}; dims=:, mean=nothing, corrected=true
13+
A::AnyTracedRArray{T,N}; dims=:, mean=nothing, corrected=true
1314
) where {T,N}
15+
A = materialize_traced_array(A)
1416
mean === nothing && (mean = Statistics.mean(A; dims))
1517
denom = (dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)) - corrected
1618
return mapreduce(abs2, +, A .- mean; dims) / denom

src/TracedRArray.jl

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,21 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
1616
end
1717
end
1818

19-
const AnyTracedRArray{T,N} = Union{
20-
TracedRArray{T,N},WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
21-
}
19+
const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
20+
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
2221
const AnyTracedRScalar{T} = AnyTracedRArray{T,0}
2322
const AnyTracedRVector{T} = AnyTracedRArray{T,1}
2423
const AnyTracedRMatrix{T} = AnyTracedRArray{T,2}
2524
const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
2625

26+
materialize_traced_array(x::TracedRArray) = x
27+
materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...]
28+
2729
get_mlir_data(x::TracedRArray) = x.mlir_data
28-
get_mlir_data(x::AnyTracedRArray) = get_mlir_data(x[axes(x)...])
30+
get_mlir_data(x::AnyTracedRArray) = get_mlir_data(materialize_traced_array(x))
2931

3032
ancestor(x::TracedRArray) = x
31-
function ancestor(x::WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}) where {T,N}
32-
return ancestor(parent(x))
33-
end
33+
ancestor(x::WrappedTracedRArray) = ancestor(parent(x))
3434

3535
get_ancestor_indices(::TracedRArray, indices...) = indices
3636
function get_ancestor_indices(
@@ -88,15 +88,11 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
8888
end
8989

9090
# Prevent ambiguity
91-
function Base.getindex(
92-
a::WrappedArray{T,N,TracedRArray,<:TracedRArray{T,N}}, index::Int...
93-
) where {T,N}
91+
function Base.getindex(a::WrappedTracedRArray, index::Int...)
9492
return getindex(ancestor(a), get_ancestor_indices(a, index...)...)
9593
end
9694

97-
function Base.getindex(
98-
a::WrappedArray{T,N,TracedRArray,<:TracedRArray{T,N}}, indices...
99-
) where {T,N}
95+
function Base.getindex(a::WrappedTracedRArray, indices...)
10096
return getindex(ancestor(a), get_ancestor_indices(a, indices...)...)
10197
end
10298

@@ -783,14 +779,8 @@ function broadcast_to_size(arg::AbstractArray, rsize)
783779
return arg
784780
end
785781

786-
function broadcast_to_size(
787-
arg::WrappedArray{T,N,TracedRArray,<:TracedRArray{T,N}}, rsize
788-
) where {T,N}
789-
return broadcast_to_size(arg[axes(arg)...], rsize)
790-
end
791-
792-
function broadcast_to_size(arg::TracedRArray, rsize)
793-
return arg
782+
function broadcast_to_size(arg::AnyTracedRArray, rsize)
783+
return materialize_traced_array(arg)
794784
end
795785

796786
function broadcast_to_size(arg::Base.RefValue, rsize)

test/wrapped_arrays.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Reactant, Test
1+
using Reactant, Test, Statistics
22

33
function view_getindex_1(x)
44
x = view(x, 2:3, 1:2, :)
@@ -80,3 +80,21 @@ end
8080
@test bcast_compiled(op, x_ra) bcast_wrapper(op, x)
8181
end
8282
end
83+
84+
function mean_var(x)
85+
x = view(x, 2:3, :)
86+
return mean(x; dims=1), var(x; dims=1)
87+
end
88+
89+
@testset "mean/var" begin
90+
x = rand(4, 3)
91+
x_ra = Reactant.to_rarray(x)
92+
93+
mean_var_compiled = @compile mean_var(x_ra)
94+
95+
m1, v1 = mean_var(x)
96+
m2, v2 = mean_var_compiled(x_ra)
97+
98+
@test m1 m2
99+
@test v1 v2
100+
end

0 commit comments

Comments
 (0)