Skip to content

Commit bbdc647

Browse files
committed
fix bmm gradient
1 parent 5042189 commit bbdc647

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

src/batchedmul.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,28 +39,27 @@ function ∇batched_mul(Δ::AbstractArray{T, 3}, A::AbstractArray{T, 3}, B::Abst
3939
(batched_mul(Δ, batched_transpose(B)), batched_mul(batched_transpose(A), Δ))
4040
end
4141

42-
4342
function ∇batched_mul(Δ::AbstractArray{T, 3}, A::BatchedTranspose{T, <: AbstractArray{T, 3}}, B::AbstractArray{T, 3}) where T
44-
(batched_mul(Δ, batched_transpose(B)), batched_mul(A, Δ))
43+
(batched_mul(B, batched_transpose(Δ)), batched_mul(batched_transpose(A), Δ))
4544
end
4645

4746
function ∇batched_mul(Δ::AbstractArray{T, 3}, A::AbstractArray{T, 3}, B::BatchedTranspose{T, <: AbstractArray{T, 3}}) where T
48-
(batched_mul(Δ, B), batched_mul(batched_transpose(A), Δ))
47+
(batched_mul(Δ, batched_transpose(B)), batched_mul(batched_transpose(Δ), A))
4948
end
5049

5150
function ∇batched_mul(Δ::AbstractArray{T, 3}, A::BatchedTranspose{T, <: AbstractArray{T, 3}}, B::BatchedTranspose{T, <: AbstractArray{T, 3}}) where T
52-
(batched_mul(batched_transpose(Δ), batched_transpose(B)), batched_mul(batched_transpose(A), batched_transpose(Δ)))
51+
(batched_mul(batched_transpose(B), batched_transpose(Δ)), batched_mul(batched_transpose(Δ), batched_transpose(A)))
5352
end
5453

5554
function ∇batched_mul(Δ::AbstractArray{T, 3}, A::BatchedAdjoint{T, <: AbstractArray{T, 3}}, B::AbstractArray{T, 3}) where T
56-
(batched_mul(Δ, batched_adjoint(B)), batched_mul(A, Δ))
55+
(batched_mul(B, batched_adjoint(Δ)), batched_mul(batched_adjoint(A), Δ))
5756
end
5857

5958
function ∇batched_mul(Δ::AbstractArray{T, 3}, A::AbstractArray{T, 3}, B::BatchedAdjoint{T, <: AbstractArray{T, 3}}) where T
60-
(batched_mul(Δ, B), batched_mul(batched_adjoint(A), Δ))
59+
(batched_mul(Δ, batched_adjoint(B)), batched_mul(batched_adjoint(Δ), A))
6160
end
6261

6362
function ∇batched_mul(Δ::AbstractArray{T, 3}, A::BatchedAdjoint{T, <: AbstractArray{T, 3}}, B::BatchedAdjoint{T, <: AbstractArray{T, 3}}) where T
64-
(batched_mul(batched_adjoint(Δ), batched_adjoint(B)), batched_mul(batched_adjoint(A), batched_adjoint(Δ)))
63+
(batched_mul(batched_adjoint(B), batched_adjoint(Δ)), batched_mul(batched_adjoint(Δ), batched_adjoint(A)))
6564
end
6665

0 commit comments

Comments
 (0)