Skip to content

Commit 565a1a2

Browse files
authored
feat: overload circshift (#1386)
* feat: overload circshift * test: fftshift
1 parent 7d3511b commit 565a1a2

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

src/TracedRArray.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1410,4 +1410,27 @@ function Base._reverse!(a::AnyTracedRArray{T,N}, dims::NTuple{M,Int}) where {T,N
14101410
return a
14111411
end
14121412

1413+
function Base.circshift!(
1414+
dest::AnyTracedRArray{T,N}, src, shiftamt::Base.DimsInteger
1415+
) where {T,N}
1416+
src = TracedUtils.promote_to(TracedRArray{T,N}, materialize_traced_array(src))
1417+
shiftamt = Base.fill_to_length(shiftamt, 0, Val(N))
1418+
1419+
for i in 1:N
1420+
amt = shiftamt[i] % size(src, i)
1421+
amt == 0 && continue
1422+
if amt > 0
1423+
src1 = selectdim(src, i, (size(src, i) - amt + 1):size(src, i))
1424+
src2 = selectdim(src, i, 1:(size(src, i) - amt))
1425+
else
1426+
src1 = selectdim(src, i, (-amt + 1):size(src, i))
1427+
src2 = selectdim(src, i, 1:(-amt))
1428+
end
1429+
src = cat(src1, src2; dims=i)
1430+
end
1431+
1432+
copyto!(dest, src)
1433+
return dest
1434+
end
1435+
14131436
end

test/basic.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1337,3 +1337,13 @@ sameunitrange(x, y) = first(x) == first(y) && last(x) == last(y)
13371337
end
13381338
end
13391339
end
1340+
1341+
@testset "circshift" begin
1342+
x = reshape(collect(Float32, 1:36), 2, 6, 3)
1343+
x_ra = Reactant.to_rarray(x)
1344+
1345+
@test @jit(circshift(x_ra, (1, 2))) circshift(x, (1, 2))
1346+
@test @jit(circshift(x_ra, (1, 2, 3))) circshift(x, (1, 2, 3))
1347+
@test @jit(circshift(x_ra, (-3, 2))) circshift(x, (-3, 2))
1348+
@test @jit(circshift(x_ra, (5, 2))) circshift(x, (5, 2))
1349+
end

test/integration/fft.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ using FFTW, Reactant, Test
2121
y_ra = @jit(fft(x_ra))
2222
@test @jit(ifft(y_ra)) x
2323

24+
shifted_fft = @jit(fftshift(y_ra))
25+
@test shifted_fft fftshift(Array(y_ra))
26+
@test @jit(ifftshift(shifted_fft)) Array(y_ra)
27+
2428
@testset "fft real input" begin
2529
x = rand(Float32, 2, 3, 4)
2630
x_ra = Reactant.to_rarray(x)

0 commit comments

Comments
 (0)