Skip to content

Commit a80505b

Browse files
committed
Fix cast to Float32 when beta != 0
1 parent 069a216 commit a80505b

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

lib/mpsgraphs/matmul.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab}, b::MtlArray{Tab}, alpha::Nu
3434
afteralpha = if isone(alpha)
3535
matmul
3636
else
37-
alphatensor = constantWithScalar(graph, alpha, Tc)
37+
alphatensor = constantWithScalar(graph, alpha, Float32)
3838
multiplicationWithPrimaryTensor(graph, alphatensor, matmul)
3939
end
4040

@@ -43,8 +43,9 @@ function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab}, b::MtlArray{Tab}, alpha::Nu
4343
else
4444
placeC = placeholderTensor(graph, size(c), Tc)
4545
feeds[placeC] = outputTensorData
46-
betatensor = constantWithScalar(graph, beta, Tc)
47-
betaC = multiplicationWithPrimaryTensor(graph, betatensor, placeC)
46+
betatensor = constantWithScalar(graph, beta, Float32)
47+
castplaceC = castTensor(graph, placeC, Float32, "castplaceC")
48+
betaC = multiplicationWithPrimaryTensor(graph, betatensor, castplaceC)
4849
additionWithPrimaryTensor(graph, afteralpha, betaC)
4950
end
5051

0 commit comments

Comments
 (0)