We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 069a216 commit a80505bCopy full SHA for a80505b
lib/mpsgraphs/matmul.jl
@@ -34,7 +34,7 @@ function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab}, b::MtlArray{Tab}, alpha::Nu
34
afteralpha = if isone(alpha)
35
matmul
36
else
37
- alphatensor = constantWithScalar(graph, alpha, Tc)
+ alphatensor = constantWithScalar(graph, alpha, Float32)
38
multiplicationWithPrimaryTensor(graph, alphatensor, matmul)
39
end
40
@@ -43,8 +43,9 @@ function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab}, b::MtlArray{Tab}, alpha::Nu
43
44
placeC = placeholderTensor(graph, size(c), Tc)
45
feeds[placeC] = outputTensorData
46
- betatensor = constantWithScalar(graph, beta, Tc)
47
- betaC = multiplicationWithPrimaryTensor(graph, betatensor, placeC)
+ betatensor = constantWithScalar(graph, beta, Float32)
+ castplaceC = castTensor(graph, placeC, Float32, "castplaceC")
48
+ betaC = multiplicationWithPrimaryTensor(graph, betatensor, castplaceC)
49
additionWithPrimaryTensor(graph, afteralpha, betaC)
50
51
0 commit comments