Skip to content

Commit cacbdbf

Browse files
authored
Fix init from double (#2861)
1 parent 193cdcd commit cacbdbf

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

python/src/convert.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -406,10 +406,16 @@ mx::array array_from_list_impl(
406406
}
407407
}
408408
case pyfloat: {
409-
std::vector<float> vals;
410-
fill_vector(pl, vals);
411-
return mx::array(
412-
vals.begin(), shape, specified_type.value_or(mx::float32));
409+
auto out_type = specified_type.value_or(mx::float32);
410+
if (out_type == mx::float64) {
411+
std::vector<double> vals;
412+
fill_vector(pl, vals);
413+
return mx::array(vals.begin(), shape, out_type);
414+
} else {
415+
std::vector<float> vals;
416+
fill_vector(pl, vals);
417+
return mx::array(vals.begin(), shape, out_type);
418+
}
413419
}
414420
case pycomplex: {
415421
std::vector<std::complex<float>> vals;
@@ -470,7 +476,12 @@ mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {
470476
: mx::int32;
471477
return mx::array(val, t.value_or(default_type));
472478
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
473-
return mx::array(nb::cast<float>(*pv), t.value_or(mx::float32));
479+
auto out_type = t.value_or(mx::float32);
480+
if (out_type == mx::float64) {
481+
return mx::array(nb::cast<double>(*pv), out_type);
482+
} else {
483+
return mx::array(nb::cast<float>(*pv), out_type);
484+
}
474485
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
475486
return mx::array(
476487
static_cast<mx::complex64_t>(*pv), t.value_or(mx::complex64));

python/tests/test_array.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,14 @@ def test_construction_from_lists(self):
434434
x = mx.array([0, 4294967295], dtype=mx.float32)
435435
self.assertTrue(np.array_equal(x, xnp))
436436

437+
def test_double_keeps_precision(self):
438+
x = 39.14223403241
439+
out = mx.array(x, dtype=mx.float64).item()
440+
self.assertEqual(out, x)
441+
442+
out = mx.array([x], dtype=mx.float64).item()
443+
self.assertEqual(out, x)
444+
437445
def test_construction_from_lists_of_mlx_arrays(self):
438446
dtypes = [
439447
mx.bool_,

0 commit comments

Comments
 (0)