Skip to content

Commit 311498b

Browse files
authored
feat: define outer repeat method for TracedRArray (EnzymeAD#361)
* Add repeat method * Add repeat tests * Update test/basic.jl * Update src/TracedRArray.jl
1 parent 8b90501 commit 311498b

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

src/TracedRArray.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,3 +729,27 @@ end
729729

730730
Base.all(f::Function, x::AnyTracedRArray) = mapreduce(f, &, x)
731731
Base.any(f::Function, x::AnyTracedRArray) = mapreduce(f, |, x)
732+
733+
# outer repeat
734+
function Base.repeat(x::AnyTracedRArray{T,N}, counts::Vararg{Int,M}) where {T,N,M}
735+
P = max(N, M) # potentially padded
736+
737+
# (d1, d2, ..., dP) -> (d1, 1, d2, 1, ..., dP, 1)
738+
interleaved_size = ones(Int, 2P)
739+
interleaved_size[1:2:2N] .= size(x)
740+
741+
x_interleaved = reshape(x, interleaved_size...)
742+
743+
# (d1, 1, d2, 1, ..., dP, 1) -> (d1, r1, d2, r2, ..., dP, rP)
744+
broadcast_target_size = interleaved_size
745+
broadcast_target_size[2:2:2M] .= counts
746+
747+
x_broadcasted = broadcast_to_size(x_interleaved, broadcast_target_size)
748+
749+
# (d1, r1, d2, r2, ..., dP, rP) -> (d1*r1, d2*r2, ..., dP*rP)
750+
final_size = vec(prod(reshape(broadcast_target_size, 2, :), dims=1))
751+
752+
x_final = reshape(x_broadcasted, final_size...)
753+
754+
return x_final
755+
end

test/basic.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,16 @@ end
364364
end
365365
end
366366

367+
@testset "repeat" begin
368+
@testset for (size, counts) in Iterators.product(
369+
[(2,), (2,3), (2,3,4), (2,3,4,5)],
370+
[(), (1,), (2,), (2,1), (1,2), (2,2), (2,2,2), (1,1,1,1,1)]
371+
)
372+
x = rand(size...)
373+
@test (@jit repeat(Reactant.to_rarray(x), counts...)) == repeat(x, counts...)
374+
end
375+
end
376+
367377
function update_on_copy(x)
368378
y = x[1:2, 2:4, :]
369379
y[1:1, 1:1, :] = ones(1, 1, 3)

0 commit comments

Comments
 (0)