Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
Enzyme-JAX is a C++ project whose original aim was to integrate the Enzyme automatic differentiation tool [1] with JAX, enabling automatic differentiation of external C++ code within JAX. It has since expanded to incorporate Polygeist's [2] high performance raising, parallelization, cross compilation workflow, as well as numerous tensor, linear algerba, and communication optimizations. The project uses LLVM's MLIR framework for intermediate representation and transformation of code. As Enzyme is language-agnostic, this can be extended for arbitrary programming
languages (Julia, Swift, Fortran, Rust, and even Python)!

You can use
# Usage Examples

## Usage with C++

You can use `cpp_call` to differentiate external C++ code:

```python
from enzyme_ad.jax import cpp_call
Expand All @@ -29,6 +33,43 @@ primals, f_vjp = jax.vjp(something, ones)
(grads,) = f_vjp((x,))
```

## Usage with Jax

You can also use Enzyme to optimize and differentiate vanilla JAX code using the `@enzyme_jax_ir` decorator. This allows applying Enzyme's optimizations and AD to standard JAX functions.

```python
from enzyme_ad.jax import enzyme_jax_ir
import jax
import jax.numpy as jnp

# Apply Enzyme optimizations and AD support
@jax.jit
@enzyme_jax_ir
def add_one(x, y):
return x + 1 + y

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([10.0, 20.0, 30.0])

# Run the function
result = add_one(x, y)
print("Result:", result)

# Forward-mode AD (JVP)
primals, tangents = jax.jvp(
add_one,
(x, y),
(jnp.array([0.1, 0.2, 0.3]), jnp.array([50.0, 70.0, 110.0])),
)
print("Primals:", primals)
print("Tangents:", tangents)

# Reverse-mode AD (VJP)
primals, f_vjp = jax.vjp(add_one, x, y)
grads = f_vjp(jnp.array([500.0, 700.0, 110.0]))
print("Gradients:", grads)
```

# Installation

The easiest way to install is using pip.
Expand Down