Skip to content

Commit eff0e31

Browse files
authored
Fix export scatters (#2852)
1 parent 6c5785b commit eff0e31

File tree

3 files changed

+19
-5
lines changed

3 files changed

+19
-5
lines changed

mlx/export.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,7 @@ struct PrimitiveFactory {
382382
SERIALIZE_PRIMITIVE(LogicalOr),
383383
SERIALIZE_PRIMITIVE(LogAddExp),
384384
SERIALIZE_PRIMITIVE(LogSumExp),
385+
SERIALIZE_PRIMITIVE(MaskedScatter),
385386
SERIALIZE_PRIMITIVE(Matmul),
386387
SERIALIZE_PRIMITIVE(Maximum),
387388
SERIALIZE_PRIMITIVE(Minimum),

mlx/primitives.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1871,13 +1871,13 @@ class Scatter : public UnaryPrimitive {
18711871
const char* name() const override {
18721872
switch (reduce_type_) {
18731873
case Sum:
1874-
return "ScatterSum";
1874+
return "Scatter Sum";
18751875
case Prod:
1876-
return "ScatterProd";
1876+
return "Scatter Prod";
18771877
case Min:
1878-
return "ScatterMin";
1878+
return "Scatter Min";
18791879
case Max:
1880-
return "ScatterMax";
1880+
return "Scatter Max";
18811881
case None:
18821882
return "Scatter";
18831883
}
@@ -1910,7 +1910,7 @@ class ScatterAxis : public UnaryPrimitive {
19101910
const char* name() const override {
19111911
switch (reduce_type_) {
19121912
case Sum:
1913-
return "ScatterAxisSum";
1913+
return "ScatterAxis Sum";
19141914
case None:
19151915
return "ScatterAxis";
19161916
}

python/tests/test_export_import.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,19 @@ def fun(y):
596596
for y in ys:
597597
self.assertEqual(imported(y)[0].item(), fun(y).item())
598598

599+
def test_export_import_scatter_sum(self):
600+
def fun(x, y, z):
601+
return x.at[y].add(z)
602+
603+
x = mx.array([1, 2, 3])
604+
y = mx.array([0, 0, 1])
605+
z = mx.array([1, 1, 1])
606+
path = os.path.join(self.test_dir, "fn.mlxfn")
607+
mx.export_function(path, fun, x, y, z)
608+
609+
imported = mx.import_function(path)
610+
self.assertTrue(mx.array_equal(imported(x, y, z)[0], fun(x, y, z)))
611+
599612

600613
if __name__ == "__main__":
601614
mlx_tests.MLXTestRunner()

0 commit comments

Comments
 (0)