@@ -3846,6 +3846,62 @@ std::vector<array> Reduce::vjp(
38463846 }
38473847}
38483848
3849+ std::vector<array> Reduce::jvp (
3850+ const std::vector<array>& primals,
3851+ const std::vector<array>& tangents,
3852+ const std::vector<int >& argnums) {
3853+ auto in = primals[0 ];
3854+ auto s = stream ();
3855+
3856+ auto grad_op = [&s, reduce_type = reduce_type_](
3857+ const array& x, const array& tan, int axis) {
3858+ if (reduce_type == Reduce::Min) {
3859+ auto idx = argmin (x, axis, true , s);
3860+ return take_along_axis (tan, idx, axis, s);
3861+ } else if (reduce_type == Reduce::Max) {
3862+ auto idx = argmax (x, axis, true , s);
3863+ return take_along_axis (tan, idx, axis, s);
3864+ } else {
3865+ auto p1 = cumprod (x, axis, /* reverse=*/ false , /* inclusive=*/ false , s);
3866+ auto p2 = cumprod (x, axis, /* reverse=*/ true , /* inclusive=*/ false , s);
3867+ auto out = multiply (multiply (p1, p2, s), tan, s);
3868+ return sum (out, axis, true , s);
3869+ }
3870+ };
3871+
3872+ auto tan = tangents[0 ];
3873+ if (reduce_type_ == Reduce::Sum) {
3874+ return {sum (tan, axes_, true , s)};
3875+ } else {
3876+ if (axes_.size () > 1 ) {
3877+ std::vector<int > transpose_to;
3878+ {
3879+ // Find the transpose needed to move axes_ to the back.
3880+ int j = 0 ;
3881+ for (int i = 0 ; i < in.ndim (); i++) {
3882+ if (j < axes_.size () && axes_[j] == i) {
3883+ j++;
3884+ } else {
3885+ transpose_to.push_back (i);
3886+ }
3887+ }
3888+ for (auto ax : axes_) {
3889+ transpose_to.push_back (ax);
3890+ }
3891+ }
3892+
3893+ int start_ax = in.ndim () - axes_.size ();
3894+ in = flatten (transpose (in, transpose_to, s), start_ax, -1 , s);
3895+ tan = flatten (transpose (tan, transpose_to, s), start_ax, -1 , s);
3896+
3897+ auto grad = squeeze (grad_op (in, tan, -1 ), -1 , s);
3898+ return {expand_dims (grad, axes_, s)};
3899+ } else {
3900+ return {grad_op (in, tan, axes_[0 ])};
3901+ }
3902+ }
3903+ }
3904+
38493905std::pair<std::vector<array>, std::vector<int >> Reduce::vmap (
38503906 const std::vector<array>& inputs,
38513907 const std::vector<int >& axes) {
0 commit comments