Skip to content

Commit 7f6ea50

Browse files
authored
Refactor batched_vec (#464)
Branching on the type of the second argument caused a subtle performance bug when differentiating via `Zygote`; see #462
1 parent 16b7486 commit 7f6ea50

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

src/batched/batchedmul.jl

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -172,13 +172,12 @@ julia> batched_vec(A,b) |> size
172172
(16, 32)
173173
```
174174
"""
175-
function batched_vec(A::AbstractArray{T,3} where T, B::AbstractMatrix)
176-
# If B is transposed, then stride=1 is the batch dim, so we will end up copying anyway:
177-
if B isa AdjOrTransAbsMat{<:BlasFloat, <:StridedMatrix}
178-
return batched_vec(A, copy(B))
179-
end
175+
batched_vec(A::AbstractArray{T,3} where T, B::AbstractMatrix) =
180176
reshape(batched_mul(A, reshape(B, size(B,1), 1, size(B,2))), size(A,1), size(A,3))
181-
end
177+
178+
# If B is transposed, then stride=1 is the batch dim, so we will end up copying anyway:
179+
batched_vec(A::AbstractArray{T,3} where T, B::AdjOrTransAbsMat{<:BlasFloat, <:StridedMatrix}) =
180+
batched_vec(A, copy(B))
182181

183182
batched_vec(A::AbstractArray{T,3} where T, b::AbstractVector) =
184183
reshape(batched_mul(A, reshape(b, length(b), 1, 1)), size(A,1), size(A,3))

0 commit comments

Comments
 (0)