Skip to content

Commit d8ceae7

Browse files
authored
Reduce JVP (#2854)
1 parent eff0e31 commit d8ceae7

File tree

3 files changed

+73
-6
lines changed

3 files changed

+73
-6
lines changed

mlx/primitives.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3846,6 +3846,62 @@ std::vector<array> Reduce::vjp(
38463846
}
38473847
}
38483848

3849+
std::vector<array> Reduce::jvp(
3850+
const std::vector<array>& primals,
3851+
const std::vector<array>& tangents,
3852+
const std::vector<int>& argnums) {
3853+
auto in = primals[0];
3854+
auto s = stream();
3855+
3856+
auto grad_op = [&s, reduce_type = reduce_type_](
3857+
const array& x, const array& tan, int axis) {
3858+
if (reduce_type == Reduce::Min) {
3859+
auto idx = argmin(x, axis, true, s);
3860+
return take_along_axis(tan, idx, axis, s);
3861+
} else if (reduce_type == Reduce::Max) {
3862+
auto idx = argmax(x, axis, true, s);
3863+
return take_along_axis(tan, idx, axis, s);
3864+
} else {
3865+
auto p1 = cumprod(x, axis, /*reverse=*/false, /*inclusive=*/false, s);
3866+
auto p2 = cumprod(x, axis, /*reverse=*/true, /*inclusive=*/false, s);
3867+
auto out = multiply(multiply(p1, p2, s), tan, s);
3868+
return sum(out, axis, true, s);
3869+
}
3870+
};
3871+
3872+
auto tan = tangents[0];
3873+
if (reduce_type_ == Reduce::Sum) {
3874+
return {sum(tan, axes_, true, s)};
3875+
} else {
3876+
if (axes_.size() > 1) {
3877+
std::vector<int> transpose_to;
3878+
{
3879+
// Find the transpose needed to move axes_ to the back.
3880+
int j = 0;
3881+
for (int i = 0; i < in.ndim(); i++) {
3882+
if (j < axes_.size() && axes_[j] == i) {
3883+
j++;
3884+
} else {
3885+
transpose_to.push_back(i);
3886+
}
3887+
}
3888+
for (auto ax : axes_) {
3889+
transpose_to.push_back(ax);
3890+
}
3891+
}
3892+
3893+
int start_ax = in.ndim() - axes_.size();
3894+
in = flatten(transpose(in, transpose_to, s), start_ax, -1, s);
3895+
tan = flatten(transpose(tan, transpose_to, s), start_ax, -1, s);
3896+
3897+
auto grad = squeeze(grad_op(in, tan, -1), -1, s);
3898+
return {expand_dims(grad, axes_, s)};
3899+
} else {
3900+
return {grad_op(in, tan, axes_[0])};
3901+
}
3902+
}
3903+
}
3904+
38493905
std::pair<std::vector<array>, std::vector<int>> Reduce::vmap(
38503906
const std::vector<array>& inputs,
38513907
const std::vector<int>& axes) {

mlx/primitives.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1751,12 +1751,7 @@ class Reduce : public UnaryPrimitive {
17511751
void eval_gpu(const std::vector<array>& inputs, array& out) override;
17521752

17531753
DEFINE_VMAP()
1754-
1755-
std::vector<array> vjp(
1756-
const std::vector<array>& primals,
1757-
const std::vector<array>& cotangents,
1758-
const std::vector<int>& argnums,
1759-
const std::vector<array>& outputs) override;
1754+
DEFINE_GRADS();
17601755

17611756
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
17621757

python/tests/test_autograd.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,22 @@ def loss_fn(model):
798798
grad_fn(model)
799799
self.assertEqual(model[1].item(), 2.0)
800800

801+
def test_reduce_jvp(self):
    # JVPs of the reduction ops on a small 1-D input.
    # With primal [0, 1, 2, 3] and tangent [3, 2, 1, 0]:
    #   sum  -> sum of tangent = 6
    #   prod -> grad is [6, 0, 0, 0], dotted with tangent = 18
    #   min  -> tangent at argmin (index 0) = 3
    #   max  -> tangent at argmax (index 3) = 0
    primal = mx.arange(4)
    tangent = mx.array([3, 2, 1, 0])

    for op, expected in ((mx.sum, 6), (mx.prod, 18), (mx.min, 3), (mx.max, 0)):
        out, jout = mx.jvp(op, primals=(primal,), tangents=(tangent,))
        self.assertEqual(jout[0].item(), expected)
816+
801817

802818
if __name__ == "__main__":
803819
mlx_tests.MLXTestRunner()

0 commit comments

Comments (0)