Skip to content

Commit 8879ee0

Browse files
authored
Support more Numpy interfaces for masked_scatter (#2832)
1 parent 6e762fe commit 8879ee0

File tree

4 files changed

+34
-10
lines changed

4 files changed

+34
-10
lines changed

docs/src/usage/indexing.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,8 @@ assignments, ``updates`` must provide at least as many elements as there are
179179
180180
Boolean masks follow NumPy semantics:
181181
182-
- The mask shape must match the shape of the axes it indexes exactly. No mask
183-
broadcasting occurs.
182+
- The mask shape must match the shape of the axes it indexes exactly. The only
183+
exception is a scalar boolean mask, which broadcasts to the full array.
184184
- Any axes not covered by the mask are taken in full.
185185
186186
.. code-block:: shell

mlx/ops.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3466,10 +3466,8 @@ array masked_scatter(
34663466
if (mask.dtype() != bool_) {
34673467
throw std::invalid_argument("[masked_scatter] The mask has to be boolean.");
34683468
}
3469-
if (mask.ndim() == 0) {
3470-
throw std::invalid_argument(
3471-
"[masked_scatter] Scalar masks are not supported.");
3472-
} else if (mask.ndim() > a.ndim()) {
3469+
3470+
if (mask.ndim() > a.ndim()) {
34733471
throw std::invalid_argument(
34743472
"[masked_scatter] The mask cannot have more dimensions than the target.");
34753473
}

python/src/indexing.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,7 @@ auto mlx_slice_update(
766766
const nb::object& obj,
767767
const ScalarOrArray& v) {
768768
// Can't route to slice update if not slice, tuple, or int
769-
if (src.ndim() == 0 ||
769+
if (src.ndim() == 0 || nb::isinstance<nb::bool_>(obj) ||
770770
(!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj) &&
771771
!nb::isinstance<nb::int_>(obj))) {
772772
return std::make_pair(false, src);
@@ -888,7 +888,9 @@ auto mlx_slice_update(
888888

889889
std::optional<mx::array> extract_boolean_mask(const nb::object& obj) {
890890
using NDArray = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
891-
if (nb::isinstance<mx::array>(obj)) {
891+
if (nb::isinstance<nb::bool_>(obj)) {
892+
return mx::array(nb::cast<bool>(obj), mx::bool_);
893+
} else if (nb::isinstance<mx::array>(obj)) {
892894
auto mask = nb::cast<mx::array>(obj);
893895
if (mask.dtype() == mx::bool_) {
894896
return mask;
@@ -898,6 +900,11 @@ std::optional<mx::array> extract_boolean_mask(const nb::object& obj) {
898900
if (mask.dtype() == nb::dtype<bool>()) {
899901
return nd_array_to_mlx(mask, mx::bool_);
900902
}
903+
} else if (nb::isinstance<nb::list>(obj)) {
904+
auto mask = array_from_list(nb::cast<nb::list>(obj), {});
905+
if (mask.dtype() == mx::bool_) {
906+
return mask;
907+
}
901908
}
902909
return std::nullopt;
903910
}

python/tests/test_array.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1929,8 +1929,27 @@ def test_setitem_with_list(self):
19291929
self.assertTrue(np.array_equal(a, anp))
19301930

19311931
def test_setitem_with_boolean_mask(self):
1932-
mask_np = np.zeros((10, 10), dtype=bool)
1933-
mx.arange(1000).reshape(10, 10, 10)[mask_np] = 0
1932+
# Python list mask
1933+
a = mx.array([1.0, 2.0, 3.0])
1934+
mask = [True, False, True]
1935+
src = mx.array([5.0, 6.0])
1936+
expected = mx.array([5.0, 2.0, 6.0])
1937+
a[mask] = src
1938+
self.assertTrue(mx.array_equal(a, expected))
1939+
1940+
# mx.array scalar mask
1941+
a = mx.array([1.0, 2.0, 3.0])
1942+
mask = mx.array(True)
1943+
expected = mx.array([5.0, 5.0, 5.0])
1944+
a[mask] = 5.0
1945+
self.assertTrue(mx.array_equal(a, expected))
1946+
1947+
# scalar mask
1948+
a = mx.array([1.0, 2.0, 3.0])
1949+
mask = True
1950+
expected = mx.array([5.0, 5.0, 5.0])
1951+
a[mask] = 5.0
1952+
self.assertTrue(mx.array_equal(a, expected))
19341953

19351954
mask_np = np.zeros((1, 10, 10), dtype=bool)
19361955
with self.assertRaises(ValueError):

0 commit comments

Comments
 (0)