Skip to content

Commit 7c14344

Browse files
authored
Add jvp op (#21720)
* add jvp op * add jvp op * bug fix * add symbolic call * fix doc.
1 parent 94ca6ef commit 7c14344

File tree

11 files changed

+185
-0
lines changed

11 files changed

+185
-0
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from keras.src.ops.linalg import eig as eig
3838
from keras.src.ops.linalg import eigh as eigh
3939
from keras.src.ops.linalg import inv as inv
40+
from keras.src.ops.linalg import jvp as jvp
4041
from keras.src.ops.linalg import lstsq as lstsq
4142
from keras.src.ops.linalg import lu_factor as lu_factor
4243
from keras.src.ops.linalg import norm as norm

keras/api/_tf_keras/keras/ops/linalg/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from keras.src.ops.linalg import eig as eig
1111
from keras.src.ops.linalg import eigh as eigh
1212
from keras.src.ops.linalg import inv as inv
13+
from keras.src.ops.linalg import jvp as jvp
1314
from keras.src.ops.linalg import lstsq as lstsq
1415
from keras.src.ops.linalg import lu_factor as lu_factor
1516
from keras.src.ops.linalg import norm as norm

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from keras.src.ops.linalg import eig as eig
3838
from keras.src.ops.linalg import eigh as eigh
3939
from keras.src.ops.linalg import inv as inv
40+
from keras.src.ops.linalg import jvp as jvp
4041
from keras.src.ops.linalg import lstsq as lstsq
4142
from keras.src.ops.linalg import lu_factor as lu_factor
4243
from keras.src.ops.linalg import norm as norm

keras/api/ops/linalg/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from keras.src.ops.linalg import eig as eig
1111
from keras.src.ops.linalg import eigh as eigh
1212
from keras.src.ops.linalg import inv as inv
13+
from keras.src.ops.linalg import jvp as jvp
1314
from keras.src.ops.linalg import lstsq as lstsq
1415
from keras.src.ops.linalg import lu_factor as lu_factor
1516
from keras.src.ops.linalg import norm as norm

keras/src/backend/jax/linalg.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,7 @@ def lstsq(a, b, rcond=None):
9797
a = convert_to_tensor(a)
9898
b = convert_to_tensor(b)
9999
return jnp.linalg.lstsq(a, b, rcond=rcond)[0]
100+
101+
102+
def jvp(fun, primals, tangents, has_aux=False):
103+
return jax.jvp(fun, primals, tangents, has_aux=has_aux)

keras/src/backend/numpy/linalg.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,7 @@ def lstsq(a, b, rcond=None):
9696
a = convert_to_tensor(a)
9797
b = convert_to_tensor(b)
9898
return np.linalg.lstsq(a, b, rcond=rcond)[0]
99+
100+
101+
def jvp(fun, primals, tangents, has_aux=False):
102+
raise NotImplementedError("JVP is not supported by the Numpy backend.")

keras/src/backend/openvino/linalg.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,7 @@ def svd(x, full_matrices=True, compute_uv=True):
5656

5757
def lstsq(a, b, rcond=None):
5858
raise NotImplementedError("`lstsq` is not supported with openvino backend")
59+
60+
61+
def jvp(fun, primals, tangents, has_aux=False):
62+
raise NotImplementedError("`jvp` is not supported with openvino backend")

keras/src/backend/tensorflow/linalg.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,27 @@ def lstsq(a, b, rcond=None):
244244
if b_orig_ndim == 1:
245245
x = tf.reshape(x, [-1])
246246
return x
247+
248+
249+
def jvp(fun, primals, tangents, has_aux=False):
250+
primal_flat = tf.nest.flatten(primals)
251+
tangent_flat = tf.nest.flatten(tangents)
252+
253+
tangent_flat = [
254+
tf.cast(t, p.dtype) for t, p in zip(tangent_flat, primal_flat)
255+
]
256+
257+
with tf.autodiff.ForwardAccumulator(primal_flat, tangent_flat) as acc:
258+
if has_aux:
259+
primals_out, aux = fun(*primals)
260+
else:
261+
primals_out = fun(*primals)
262+
263+
primals_out_flat = tf.nest.flatten(primals_out)
264+
tangents_out_flat = [acc.jvp(po) for po in primals_out_flat]
265+
266+
tangents_out = tf.nest.pack_sequence_as(primals_out, tangents_out_flat)
267+
268+
if has_aux:
269+
return primals_out, tangents_out, aux
270+
return primals_out, tangents_out

keras/src/backend/torch/linalg.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,7 @@ def lstsq(a, b, rcond=None):
8080
a = convert_to_tensor(a)
8181
b = convert_to_tensor(b)
8282
return torch.linalg.lstsq(a, b, rcond=rcond)[0]
83+
84+
85+
def jvp(fun, primals, tangents, has_aux=False):
86+
return torch.func.jvp(fun, primals, tangents, has_aux=has_aux)

keras/src/ops/linalg.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from keras.src import backend
2+
from keras.src import tree
23
from keras.src.api_export import keras_export
34
from keras.src.backend import KerasTensor
45
from keras.src.backend import any_symbolic_tensors
@@ -732,3 +733,95 @@ def _assert_a_b_compat(a, b):
732733
"Expected `a.shape[-1] == b.shape[-1]`. "
733734
f"Received: a.shape={a.shape}, b.shape={b.shape}"
734735
)
736+
737+
738+
class JVP(Operation):
739+
def __init__(self, has_aux=False, *, name=None):
740+
super().__init__(name=name)
741+
self.has_aux = has_aux
742+
743+
def call(self, fun, primals, tangents):
744+
"""Computes the JVP of `fun` at `primals` along `tangents`.
745+
746+
Args:
747+
fun: A callable that takes tensors (or nested structures) as input
748+
and returns a tensor (or nested structure) as output.
749+
primals: Input tensors (or nested structures) at which the Jacobian
750+
of `fun` is evaluated.
751+
tangents: Tensors (or nested structures) representing the direction
752+
vectors for the JVP. Must have the same structure as
753+
`primals`.
754+
755+
Returns:
756+
If `has_aux` is False:
757+
A tuple (primals_out, tangents_out) where:
758+
- primals_out: Output of `fun(*primals)`
759+
- tangents_out: JVP of `fun` at `primals` along `tangents`
760+
If `has_aux` is True:
761+
A tuple (primals_out, tangents_out, aux) where:
762+
- aux: Auxiliary data returned by `fun`
763+
"""
764+
return backend.linalg.jvp(fun, primals, tangents, has_aux=self.has_aux)
765+
766+
def compute_output_spec(self, fun, primals, tangents):
767+
# Infer primal output spec
768+
if self.has_aux:
769+
primals_out_spec, aux_spec = backend.compute_output_spec(
770+
fun, *primals
771+
)
772+
else:
773+
primals_out_spec = backend.compute_output_spec(fun, *primals)
774+
775+
# Tangents output should match primals output in structure and shape
776+
tangents_out_spec = tree.map_structure(
777+
lambda x: KerasTensor(x.shape, x.dtype), primals_out_spec
778+
)
779+
780+
if self.has_aux:
781+
return primals_out_spec, tangents_out_spec, aux_spec
782+
return primals_out_spec, tangents_out_spec
783+
784+
785+
@keras_export(["keras.ops.jvp", "keras.ops.linalg.jvp"])
786+
def jvp(fun, primals, tangents, has_aux=False):
787+
"""Computes a (forward-mode) Jacobian-vector product of `fun`.
788+
Args:
789+
fun: Function to be differentiated. Its arguments should be arrays,
790+
scalars, or standard Python containers of arrays or scalars. It
791+
should return an array, scalar, or standard Python container of
792+
arrays or scalars.
793+
primals: The primal values at which the Jacobian of `fun` should be
794+
evaluated. Should be either a tuple or a list of arguments,
795+
and its length should be equal to the number of positional
796+
parameters of `fun`.
797+
tangents: The tangent vector for which the Jacobian-vector product
798+
should be evaluated. Should be either a tuple or a list of
799+
tangents, with the same tree structure and array shapes as
800+
`primals`.
801+
has_aux: Optional, bool. Indicates whether `fun` returns a pair where
802+
the first element is considered the output of the mathematical
803+
function to be differentiated and the second element is
804+
auxiliary data. Default is False.
805+
806+
Returns:
807+
If `has_aux` is False, returns a (`primals_out`, `tangents_out`) pair,
808+
where `primals_out` is `fun(*primals)`, and `tangents_out` is the
809+
Jacobian-vector product of `fun` evaluated at `primals` with
810+
`tangents`. The `tangents_out` value has the same Python tree
811+
structure and shapes as `primals_out`.
812+
813+
If `has_aux` is True, returns a (`primals_out`, `tangents_out`, `aux`)
814+
tuple where `aux` is the auxiliary data returned by `fun`.
815+
816+
Example:
817+
>>> from keras import ops
818+
>>> a1, a2 = ops.convert_to_tensor(0.1), ops.convert_to_tensor(0.2)
819+
>>> primals, tangents = ops.jvp(ops.sin, (a1,), (a2,))
820+
>>> primals
821+
0.09983342
822+
>>> tangents
823+
0.19900084
824+
"""
825+
if any_symbolic_tensors((primals, tangents)):
826+
return JVP(has_aux=has_aux).symbolic_call(fun, primals, tangents)
827+
return backend.linalg.jvp(fun, primals, tangents, has_aux=has_aux)

0 commit comments

Comments
 (0)