Skip to content

Commit 6ac5280

Browse files
MillaFleursKD2YCUangeloskath
authored
Fix assigning bool to float16/bfloat16 (#3229)
Co-authored-by: KD2YCU <me@kd2ycu.com> Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
1 parent 572e0a4 commit 6ac5280

File tree

3 files changed

+16
-4
lines changed

3 files changed

+16
-4
lines changed

mlx/types/bf16.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <vector>
99

1010
#define __MLX_BFLOAT_NAN__ 0x7FC0
11+
#define __MLX_BFLOAT_ONE__ 0x3F80
1112

1213
namespace mlx::core {
1314

@@ -29,8 +30,8 @@ struct _MLX_BFloat16 {
2930

3031
// Appease std::vector<bool> for being special
3132
_MLX_BFloat16& operator=(std::vector<bool>::reference x) {
32-
bits_ = x;
33-
return *this;
33+
bits_ = (x) ? __MLX_BFLOAT_ONE__ : 0;
34+
return (*this);
3435
}
3536

3637
_MLX_BFloat16& operator=(const float& x) {

mlx/types/fp16.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <vector>
99

1010
#define __MLX_HALF_NAN__ 0x7D00
11+
#define __MLX_HALF_ONE__ 0x3C00
1112

1213
namespace mlx::core {
1314

@@ -29,8 +30,8 @@ struct _MLX_Float16 {
2930

3031
// Appease std::vector<bool> for being special
3132
_MLX_Float16& operator=(std::vector<bool>::reference x) {
32-
bits_ = x;
33-
return *this;
33+
bits_ = (x) ? __MLX_HALF_ONE__ : 0;
34+
return (*this);
3435
}
3536

3637
_MLX_Float16& operator=(const float& x) {

tests/array_tests.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,16 @@ TEST_CASE("test array basics") {
105105
CHECK_EQ(x.dtype(), bool_);
106106
CHECK(array_equal(x, array({false, true, false, true})).item<bool>());
107107
}
108+
109+
// Regression: vector<bool>::reference to fp16/bf16 stored raw bits
110+
{
111+
std::vector<bool> data = {true, false, true};
112+
auto bf = array(data.begin(), {3}, bfloat16);
113+
CHECK(array_equal(bf, array({1.0f, 0.0f, 1.0f}, bfloat16)).item<bool>());
114+
115+
auto fp = array(data.begin(), {3}, float16);
116+
CHECK(array_equal(fp, array({1.0f, 0.0f, 1.0f}, float16)).item<bool>());
117+
}
108118
}
109119

110120
TEST_CASE("test array types") {

0 commit comments

Comments
 (0)