Skip to content

Commit d101848

Browse files
committed
gradient for adjoint bmm
1 parent 2bd6d7e commit d101848

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

src/batchedmul.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,16 @@ end
5151
function ∇batched_mul::AbstractArray{T, 3}, A::BatchedTranspose{T, <: AbstractArray{T, 3}}, B::BatchedTranspose{T, <: AbstractArray{T, 3}}) where T
5252
(batched_mul(batched_transpose(Δ), batched_transpose(B)), batched_mul(batched_transpose(A), batched_transpose(Δ)))
5353
end
54+
55+
function ∇batched_mul::AbstractArray{T, 3}, A::BatchedAdjoint{T, <: AbstractArray{T, 3}}, B::AbstractArray{T, 3}) where T
56+
(batched_mul(Δ, batched_adjoint(B)), batched_mul(A, Δ))
57+
end
58+
59+
function ∇batched_mul::AbstractArray{T, 3}, A::AbstractArray{T, 3}, B::BatchedAdjoint{T, <: AbstractArray{T, 3}}) where T
60+
(batched_mul(Δ, B), batched_mul(batched_adjoint(A), Δ))
61+
end
62+
63+
function ∇batched_mul::AbstractArray{T, 3}, A::BatchedAdjoint{T, <: AbstractArray{T, 3}}, B::BatchedAdjoint{T, <: AbstractArray{T, 3}}) where T
64+
(batched_mul(batched_adjoint(Δ), batched_adjoint(B)), batched_mul(batched_adjoint(A), batched_adjoint(Δ)))
65+
end
66+

0 commit comments

Comments
 (0)