Skip to content

Commit 58142d5

Browse files
committed
feat: handle reshaping
1 parent fc9b577 commit 58142d5

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

src/TracedRArray.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ const AnyTracedRVector{T} = AnyTracedRArray{T,1}
2424
const AnyTracedRMatrix{T} = AnyTracedRArray{T,2}
2525
const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
2626

27+
get_mlir_data(x::TracedRArray) = x.mlir_data
28+
get_mlir_data(x::AnyTracedRArray) = get_mlir_data(x[axes(x)...])
29+
2730
Base.getindex(a::AnyTracedRScalar{T}) where {T} = a
2831

2932
function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N}
@@ -118,13 +121,13 @@ end
118121

119122
Base.only(A::AnyTracedRScalar{T}) where {T} = A
120123

121-
function Base.reshape(A::TracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT}
124+
function Base.reshape(A::AnyTracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT}
122125
prod(dims) == prod(size(A)) || Base._throw_dmrsa(dims, prod(size(A)))
123126

124127
# HLO reshape semantics collapse the opposite way
125128
res1 = MLIR.IR.result(
126129
MLIR.Dialects.stablehlo.transpose(
127-
A.mlir_data;
130+
get_mlir_data(A);
128131
permutation=MLIR.IR.DenseArrayAttribute([Int64(N - 1 - i) for i in 0:(N - 1)]),
129132
),
130133
1,

test/wrapped_arrays.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,19 @@ end
3232

3333
@test view_getindex_3_compiled(x_ra) view_getindex_3(x)
3434
end
35+
36+
function reshape_wrapper(x)
37+
x = view(x, 2:3, 1:2, :)
38+
return reshape(x, 4, :)
39+
end
40+
41+
@testset "reshape wrapper" begin
42+
x = rand(4, 4, 3)
43+
x_ra = Reactant.to_rarray(x)
44+
45+
reshape_wrapper(x)
46+
47+
reshape_wrapper_compiled = @compile reshape_wrapper(x_ra)
48+
49+
@test reshape_wrapper_compiled(x_ra) reshape_wrapper(x)
50+
end

0 commit comments

Comments
 (0)