Skip to content

Commit 8435c5e

Browse files
authored
feat: generalize indexing to all wrappers (#146)
* feat: generalize indexing to all wrappers * test: use `PermuteDimsArray` to test parentindices
1 parent fd9b469 commit 8435c5e

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

src/TracedRArray.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,8 @@ ancestor(x::TracedRArray) = x
3434
ancestor(x::WrappedTracedRArray) = ancestor(parent(x))
3535

3636
get_ancestor_indices(::TracedRArray, indices...) = indices
37-
function get_ancestor_indices(
38-
x::SubArray{T,N,<:AnyTracedRArray{T,N}}, indices...
39-
) where {T,N}
40-
return get_ancestor_indices(parent(x), Base.reindex(x.indices, indices)...)
37+
function get_ancestor_indices(x::WrappedTracedRArray, indices...)
38+
return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...)
4139
end
4240

4341
Base.getindex(a::AnyTracedRScalar{T}) where {T} = a

test/wrapped_arrays.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,17 @@ end
113113

114114
@test btranspose_badjoint_compiled(x_ra) btranspose_badjoint(x)
115115
end
116+
117+
function bypass_permutedims(x)
118+
x = PermutedDimsArray(x, (2, 1, 3)) # Don't use permutedims here
119+
return view(x, 2:3, 1:2, :)
120+
end
121+
122+
@testset "PermutedDimsArray" begin
123+
x = rand(4, 4, 3)
124+
x_ra = Reactant.to_rarray(x)
125+
126+
bypass_permutedims_compiled = @compile bypass_permutedims(x_ra)
127+
128+
@test bypass_permutedims_compiled(x_ra) bypass_permutedims(x)
129+
end

0 commit comments

Comments
 (0)