Skip to content

Commit e8fe0c8

Browse files
committed
feat: support batched transpose and adjoint
1 parent af21f68 commit e8fe0c8

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

ext/ReactantNNlibExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,4 +203,7 @@ function NNlib.meanpool(x::AnyTracedRArray{T}, pdims::NNlib.PoolDims) where {T}
203203
T(numel)
204204
end
205205

206+
NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T} = permutedims(x, (2, 1, 3))
207+
NNlib.batched_adjoint(x::AnyTracedRArray{<:Real,3}) = NNlib.batched_transpose(x)
208+
206209
end # module ReactantNNlibExt

test/wrapped_arrays.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Reactant, Test, Statistics
1+
using Reactant, Test, Statistics, NNlib
22

33
function view_getindex_1(x)
44
x = view(x, 2:3, 1:2, :)
@@ -98,3 +98,18 @@ end
9898
@test m1 m2
9999
@test v1 v2
100100
end
101+
102+
function btranspose_badjoint(x)
103+
x1 = NNlib.batched_transpose(x)
104+
x2 = NNlib.batched_adjoint(x)
105+
return x1 .+ x2
106+
end
107+
108+
@testset "batched transpose/adjoint" begin
109+
x = rand(4, 2, 3)
110+
x_ra = Reactant.to_rarray(x)
111+
112+
btranspose_badjoint_compiled = @compile btranspose_badjoint(x_ra)
113+
114+
@test btranspose_badjoint_compiled(x_ra) btranspose_badjoint(x)
115+
end

0 commit comments

Comments
 (0)