Skip to content

Commit f9a1bf8

Browse files
committed
fix: special case transpose and adjoint
1 parent 0b803e9 commit f9a1bf8

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

src/TracedRArray.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,12 @@ function Base.permutedims(A::AnyTracedRArray{T,N}, perm) where {T,N}
179179
)
180180
end
181181

182+
function Base.transpose(A::AnyTracedRVecOrMat)
183+
A = ndims(A) == 1 ? reshape(A, :, 1) : A
184+
return permutedims(A, (2, 1))
185+
end
186+
Base.adjoint(A::AnyTracedRVecOrMat{<:Real}) = transpose(A)
187+
182188
function Base.promote_rule(
183189
::Type{TracedRArray{T,N}}, ::Type{TracedRArray{S,N}}
184190
) where {T,S,N}

test/basic.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,15 @@ tuple_byref2(x) = abs2.(x), tuple_byref2(x)
261261
# @test r2[2].a.b.data === x.data
262262
# @test r2[1] == abs2.([1.0 -2.0; -3.0 4.0])
263263
end
264+
265+
sum_xxᵀ(x) = sum(x .* x')
266+
267+
@testset "sum(x .* x')" begin
268+
@testset "size(x): $(size(x))" for x in (rand(4, 4), rand(4,))
269+
x_ca = Reactant.to_rarray(x)
270+
271+
sum_xxᵀ_compiled = @compile sum_xxᵀ(x_ca)
272+
273+
@test sum_xxᵀ_compiled(x_ca) sum_xxᵀ(x)
274+
end
275+
end

0 commit comments

Comments
 (0)