|
1 | 1 | from keras.src import backend
|
| 2 | +from keras.src import tree |
2 | 3 | from keras.src.api_export import keras_export
|
3 | 4 | from keras.src.backend import KerasTensor
|
4 | 5 | from keras.src.backend import any_symbolic_tensors
|
@@ -732,3 +733,95 @@ def _assert_a_b_compat(a, b):
|
732 | 733 | "Expected `a.shape[-1] == b.shape[-1]`. "
|
733 | 734 | f"Received: a.shape={a.shape}, b.shape={b.shape}"
|
734 | 735 | )
|
| 736 | + |
| 737 | + |
| 738 | +class JVP(Operation): |
| 739 | + def __init__(self, has_aux=False, *, name=None): |
| 740 | + super().__init__(name=name) |
| 741 | + self.has_aux = has_aux |
| 742 | + |
| 743 | + def call(self, fun, primals, tangents): |
| 744 | + """Computes the JVP of `fun` at `primals` along `tangents`. |
| 745 | +
|
| 746 | + Args: |
| 747 | + fun: A callable that takes tensors (or nested structures) as input |
| 748 | + and returns a tensor (or nested structure) as output. |
| 749 | + primals: Input tensors (or nested structures) at which the Jacobian |
| 750 | + of `fun` is evaluated. |
| 751 | + tangents: Tensors (or nested structures) representing the direction |
| 752 | + vectors for the JVP. Must have the same structure as |
| 753 | + `primals`. |
| 754 | +
|
| 755 | + Returns: |
| 756 | + If `has_aux` is False: |
| 757 | + A tuple (primals_out, tangents_out) where: |
| 758 | + - primals_out: Output of `fun(*primals)` |
| 759 | + - tangents_out: JVP of `fun` at `primals` along `tangents` |
| 760 | + If `has_aux` is True: |
| 761 | + A tuple (primals_out, tangents_out, aux) where: |
| 762 | + - aux: Auxiliary data returned by `fun` |
| 763 | + """ |
| 764 | + return backend.linalg.jvp(fun, primals, tangents, has_aux=self.has_aux) |
| 765 | + |
| 766 | + def compute_output_spec(self, fun, primals, tangents): |
| 767 | + # Infer primal output spec |
| 768 | + if self.has_aux: |
| 769 | + primals_out_spec, aux_spec = backend.compute_output_spec( |
| 770 | + fun, *primals |
| 771 | + ) |
| 772 | + else: |
| 773 | + primals_out_spec = backend.compute_output_spec(fun, *primals) |
| 774 | + |
| 775 | + # Tangents output should match primals output in structure and shape |
| 776 | + tangents_out_spec = tree.map_structure( |
| 777 | + lambda x: KerasTensor(x.shape, x.dtype), primals_out_spec |
| 778 | + ) |
| 779 | + |
| 780 | + if self.has_aux: |
| 781 | + return primals_out_spec, tangents_out_spec, aux_spec |
| 782 | + return primals_out_spec, tangents_out_spec |
| 783 | + |
| 784 | + |
| 785 | +@keras_export(["keras.ops.jvp", "keras.ops.linalg.jvp"]) |
| 786 | +def jvp(fun, primals, tangents, has_aux=False): |
| 787 | + """Computes a (forward-mode) Jacobian-vector product of `fun`. |
| 788 | + Args: |
| 789 | + fun: Function to be differentiated. Its arguments should be arrays, |
| 790 | + scalars, or standard Python containers of arrays or scalars. It |
| 791 | + should return an array, scalar, or standard Python container of |
| 792 | + arrays or scalars. |
| 793 | + primals: The primal values at which the Jacobian of `fun` should be |
| 794 | + evaluated. Should be either a tuple or a list of arguments, |
| 795 | + and its length should be equal to the number of positional |
| 796 | + parameters of `fun`. |
| 797 | + tangents: The tangent vector for which the Jacobian-vector product |
| 798 | + should be evaluated. Should be either a tuple or a list of |
| 799 | + tangents, with the same tree structure and array shapes as |
| 800 | + `primals`. |
| 801 | + has_aux: Optional, bool. Indicates whether `fun` returns a pair where |
| 802 | + the first element is considered the output of the mathematical |
| 803 | + function to be differentiated and the second element is |
| 804 | + auxiliary data. Default is False. |
| 805 | +
|
| 806 | + Returns: |
| 807 | + If `has_aux` is False, returns a (`primals_out`, `tangents_out`) pair, |
| 808 | + where `primals_out` is `fun(*primals)`, and `tangents_out` is the |
| 809 | + Jacobian-vector product of `fun` evaluated at `primals` with |
| 810 | + `tangents`. The `tangents_out` value has the same Python tree |
| 811 | + structure and shapes as `primals_out`. |
| 812 | +
|
| 813 | + If `has_aux` is True, returns a (`primals_out`, `tangents_out`, `aux`) |
| 814 | + tuple where `aux` is the auxiliary data returned by `fun`. |
| 815 | +
|
| 816 | + Example: |
| 817 | + >>> from keras import ops |
| 818 | + >>> a1, a2 = ops.convert_to_tensor(0.1), ops.convert_to_tensor(0.2) |
| 819 | + >>> primals, tangents = ops.jvp(ops.sin, (a1,), (a2,)) |
| 820 | + >>> primals |
| 821 | + 0.09983342 |
| 822 | + >>> tangents |
| 823 | + 0.19900084 |
| 824 | + """ |
| 825 | + if any_symbolic_tensors((primals, tangents)): |
| 826 | + return JVP(has_aux=has_aux).symbolic_call(fun, primals, tangents) |
| 827 | + return backend.linalg.jvp(fun, primals, tangents, has_aux=has_aux) |
0 commit comments