Skip to content

Commit 6e8a35f

Browse files
james-martensGoogle-ML-Automation
authored andcommitted
Adding support for copy_p primitive to jet.
PiperOrigin-RevId: 694296952
1 parent 1f1d27d commit 6e8a35f

File tree

2 files changed

+3
-0
lines changed

2 files changed

+3
-0
lines changed

jax/experimental/jet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ def linear_prop(prim, primals_in, series_in, **params):
329329
deflinear(lax.reduce_sum_p)
330330
deflinear(lax.reduce_window_sum_p)
331331
deflinear(lax.fft_p)
332+
deflinear(lax.copy_p)
332333
deflinear(dispatch.device_put_p)
333334

334335
def _dynamic_slice_jet_rule(primals_in, series_in, **params):

tests/jet_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,8 @@ def test_cummin(self): self.unary_check(partial(lax.cummin, axis=0))
319319
def test_dynamic_slice(self): self.unary_check(partial(lax.dynamic_slice, start_indices=(1,2), slice_sizes=(1,1)))
320320
@jtu.skip_on_devices("tpu")
321321
def test_dynamic_update_slice(self): self.unary_check(partial(lax.dynamic_update_slice, start_indices=(1,2), update=np.arange(6.0).reshape(2, 3)))
322+
@jtu.skip_on_devices("tpu")
323+
def test_copy(self): self.unary_check(jnp.array)
322324

323325

324326
@jtu.skip_on_devices("tpu")

0 commit comments

Comments
 (0)