@@ -406,10 +406,16 @@ mx::array array_from_list_impl(
406406 }
407407 }
408408 case pyfloat: {
409- std::vector<float > vals;
410- fill_vector (pl, vals);
411- return mx::array (
412- vals.begin (), shape, specified_type.value_or (mx::float32));
409+ auto out_type = specified_type.value_or (mx::float32);
410+ if (out_type == mx::float64) {
411+ std::vector<double > vals;
412+ fill_vector (pl, vals);
413+ return mx::array (vals.begin (), shape, out_type);
414+ } else {
415+ std::vector<float > vals;
416+ fill_vector (pl, vals);
417+ return mx::array (vals.begin (), shape, out_type);
418+ }
413419 }
414420 case pycomplex: {
415421 std::vector<std::complex <float >> vals;
@@ -470,7 +476,12 @@ mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {
470476 : mx::int32;
471477 return mx::array (val, t.value_or (default_type));
472478 } else if (auto pv = std::get_if<nb::float_>(&v); pv) {
473- return mx::array (nb::cast<float >(*pv), t.value_or (mx::float32));
479+ auto out_type = t.value_or (mx::float32);
480+ if (out_type == mx::float64) {
481+ return mx::array (nb::cast<double >(*pv), out_type);
482+ } else {
483+ return mx::array (nb::cast<float >(*pv), out_type);
484+ }
474485 } else if (auto pv = std::get_if<std::complex <float >>(&v); pv) {
475486 return mx::array (
476487 static_cast <mx::complex64_t >(*pv), t.value_or (mx::complex64));
0 commit comments