Skip to content

Commit 64db424

Browse files
author
Michael Abbott
committed
promote types
1 parent 2d6518f commit 64db424

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

src/batched/batchedmul.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ Batched matrix multiplication. Result has `C[:,:,k] == A[:,:,k] * B[:,:,k]` for
1111
"""
1212
function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2}
1313
axes(A, 3) == axes(B, 3) || throw(DimensionMismatch("batch size mismatch"))
14-
C = similar(A, (axes(A, 1), axes(B, 2), axes(A, 3)))
14+
T = promote_type(T1, T2)
15+
C = similar(A, T, (axes(A, 1), axes(B, 2), axes(A, 3)))
1516
batched_mul!(C, A, B)
1617
end
1718

test/batchedmul.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ end
5252
iB = TB.(rand(1:99, 5,7,3))
5353
iC = zeros(Int, 7,6,3)
5454
@test batched_mul(iA, iB) == bmm_adjtest(iA, iB)
55+
@test batched_mul(cA, iB) bmm_adjtest(cA, iB)
5556

5657
@test_throws DimensionMismatch batched_mul(rand(2,2,2), rand(TB, 2,2,10))
5758
@test_throws DimensionMismatch batched_mul(rand(2,2,2), rand(TB, 10,2,2))

0 commit comments

Comments
 (0)