@@ -39,28 +39,27 @@ function ∇batched_mul(Δ::AbstractArray{T, 3}, A::AbstractArray{T, 3}, B::Abst
39
39
(batched_mul (Δ, batched_transpose (B)), batched_mul (batched_transpose (A), Δ))
40
40
end
41
41
42
-
43
42
function ∇batched_mul (Δ:: AbstractArray{T, 3} , A:: BatchedTranspose{T, <: AbstractArray{T, 3}} , B:: AbstractArray{T, 3} ) where T
44
- (batched_mul (Δ , batched_transpose (B )), batched_mul (A , Δ))
43
+ (batched_mul (B , batched_transpose (Δ )), batched_mul (batched_transpose (A) , Δ))
45
44
end
46
45
47
46
function ∇batched_mul (Δ:: AbstractArray{T, 3} , A:: AbstractArray{T, 3} , B:: BatchedTranspose{T, <: AbstractArray{T, 3}} ) where T
48
- (batched_mul (Δ, B) , batched_mul (batched_transpose (A ), Δ ))
47
+ (batched_mul (Δ, batched_transpose (B)) , batched_mul (batched_transpose (Δ ), A ))
49
48
end
50
49
51
50
function ∇batched_mul (Δ:: AbstractArray{T, 3} , A:: BatchedTranspose{T, <: AbstractArray{T, 3}} , B:: BatchedTranspose{T, <: AbstractArray{T, 3}} ) where T
52
- (batched_mul (batched_transpose (Δ ), batched_transpose (B )), batched_mul (batched_transpose (A ), batched_transpose (Δ )))
51
+ (batched_mul (batched_transpose (B ), batched_transpose (Δ )), batched_mul (batched_transpose (Δ ), batched_transpose (A )))
53
52
end
54
53
55
54
function ∇batched_mul (Δ:: AbstractArray{T, 3} , A:: BatchedAdjoint{T, <: AbstractArray{T, 3}} , B:: AbstractArray{T, 3} ) where T
56
- (batched_mul (Δ , batched_adjoint (B )), batched_mul (A , Δ))
55
+ (batched_mul (B , batched_adjoint (Δ )), batched_mul (batched_adjoint (A) , Δ))
57
56
end
58
57
59
58
function ∇batched_mul (Δ:: AbstractArray{T, 3} , A:: AbstractArray{T, 3} , B:: BatchedAdjoint{T, <: AbstractArray{T, 3}} ) where T
60
- (batched_mul (Δ, B) , batched_mul (batched_adjoint (A ), Δ ))
59
+ (batched_mul (Δ, batched_adjoint (B)) , batched_mul (batched_adjoint (Δ ), A ))
61
60
end
62
61
63
62
function ∇batched_mul (Δ:: AbstractArray{T, 3} , A:: BatchedAdjoint{T, <: AbstractArray{T, 3}} , B:: BatchedAdjoint{T, <: AbstractArray{T, 3}} ) where T
64
- (batched_mul (batched_adjoint (Δ ), batched_adjoint (B )), batched_mul (batched_adjoint (A ), batched_adjoint (Δ )))
63
+ (batched_mul (batched_adjoint (B ), batched_adjoint (Δ )), batched_mul (batched_adjoint (Δ ), batched_adjoint (A )))
65
64
end
66
65
0 commit comments