Skip to content

Commit 5cab70e

Browse files
committed
feat: handle permutedims
1 parent 46c4345 commit 5cab70e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/TracedRArray.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,12 @@ function Base.reshape(A::AnyTracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,
149149
return TracedRArray{T,NT}((), res3, dims)
150150
end
151151

152-
function Base.permutedims(A::TracedRArray{T,N}, perm) where {T,N}
152+
function Base.permutedims(A::AnyTracedRArray{T,N}, perm) where {T,N}
153153
return TracedRArray{T,N}(
154154
(),
155155
MLIR.IR.result(
156156
MLIR.Dialects.stablehlo.transpose(
157-
A.mlir_data;
157+
get_mlir_data(A);
158158
permutation=MLIR.IR.DenseArrayAttribute([Int64(i - 1) for i in perm]),
159159
),
160160
1,
@@ -169,7 +169,7 @@ function Base.promote_rule(
169169
return TracedRArray{Base.promote_type(T, S),N}
170170
end
171171

172-
function Base.promote_rule(A::Type{T}, B::Type{TracedRArray{S,N}}) where {T,S,N}
172+
function Base.promote_rule(::Type{T}, ::Type{TracedRArray{S,N}}) where {T,S,N}
173173
return TracedRArray{Base.promote_type(T, S),N}
174174
end
175175

0 commit comments

Comments
 (0)