Skip to content

Commit c4767d1

Browse files
authored
fix addmm cpu (#2699)
1 parent 895217f commit c4767d1

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

mlx/backend/cpu/gemms/bnns.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
// Copyright © 2023-2024 Apple Inc.
2-
32
#include <Accelerate/Accelerate.h>
43

54
#include "mlx/array.h"
@@ -49,9 +48,15 @@ void matmul_bnns(
4948
size_t K = a_shape[ndim - 1];
5049

5150
BNNSDataType bnns_dtype = to_bnns_dtype<T>();
52-
5351
#pragma GCC diagnostic push
5452
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
53+
if (beta != 1.0 && beta != 0.0) {
54+
// scale the output
55+
for (auto i = 0; i < batch_size * M * N; ++i) {
56+
out[i] *= beta;
57+
}
58+
beta = 1.0;
59+
}
5560
const BNNSLayerParametersBroadcastMatMul gemm_params{
5661
/* float alpha = */ alpha,
5762
/* float beta = */ beta,

python/tests/test_blas.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -717,8 +717,8 @@ def test_addmm(self):
717717
c = mx.ones((32, 32)).astype(t)
718718
a = mx.random.uniform(shape=(32, 32)).astype(t)
719719
b = mx.random.uniform(shape=(32, 32)).astype(t)
720-
out = mx.addmm(c, a, b)
721-
expected = a @ b + c
720+
out = mx.addmm(c, a, b, alpha=0.5, beta=2.0)
721+
expected = 0.5 * (a @ b) + 2.0 * c
722722
self.assertTrue(mx.allclose(out, expected, rtol=tol, atol=tol))
723723

724724
def test_addmm_grad(self):

0 commit comments

Comments
 (0)