Skip to content

Commit a90efff

Browse files
committed
feat: broadcasting on views
1 parent 5cab70e commit a90efff

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

src/TracedRArray.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,7 @@ function Base.mapreducedim!(
680680
end
681681

682682
struct AbstractReactantArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end
683+
683684
AbstractReactantArrayStyle(::Val{N}) where {N} = AbstractReactantArrayStyle{N}()
684685
AbstractReactantArrayStyle{M}(::Val{N}) where {N,M} = AbstractReactantArrayStyle{N}()
685686

@@ -690,7 +691,9 @@ AbstractReactantArrayStyle{M}(::Val{N}) where {N,M} = AbstractReactantArrayStyle
690691
# copy(inst)
691692
# end
692693

693-
BroadcastStyle(::Type{T}) where {T<:TracedRArray} = AbstractReactantArrayStyle{ndims(T)}()
694+
function BroadcastStyle(::Type{<:AnyTracedRArray{T,N}}) where {T,N}
695+
return AbstractReactantArrayStyle{N}()
696+
end
694697

695698
function Base.similar(
696699
bc::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims
@@ -758,6 +761,12 @@ function broadcast_to_size(arg::AbstractArray, rsize)
758761
return arg
759762
end
760763

764+
function broadcast_to_size(
765+
arg::WrappedArray{T,N,TracedRArray,<:TracedRArray{T,N}}, rsize
766+
) where {T,N}
767+
return broadcast_to_size(arg[axes(arg)...], rsize)
768+
end
769+
761770
function broadcast_to_size(arg::TracedRArray, rsize)
762771
return arg
763772
end

test/wrapped_arrays.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,35 @@ end
4848

4949
@test reshape_wrapper_compiled(x_ra) reshape_wrapper(x)
5050
end
51+
52+
function permutedims_wrapper(x)
53+
x = view(x, 2:3, 1:2, :)
54+
return permutedims(x, (2, 1, 3))
55+
end
56+
57+
@testset "permutedims wrapper" begin
58+
x = rand(4, 4, 3)
59+
x_ra = Reactant.to_rarray(x)
60+
61+
permutedims_wrapper(x)
62+
63+
permutedims_wrapper_compiled = @compile permutedims_wrapper(x_ra)
64+
65+
@test permutedims_wrapper_compiled(x_ra) permutedims_wrapper(x)
66+
end
67+
68+
function bcast_wrapper(f::F, x) where {F}
69+
x = view(x, 2:3, :)
70+
return f.(x)
71+
end
72+
73+
@testset "Broadcasting on wrapped arrays" begin
74+
x = rand(4, 3)
75+
x_ra = Reactant.to_rarray(x)
76+
77+
for op in (-, tanh, sin)
78+
bcast_compiled = @compile bcast_wrapper(op, x_ra)
79+
80+
@test bcast_compiled(op, x_ra) bcast_wrapper(op, x)
81+
end
82+
end

0 commit comments

Comments
 (0)