
I managed to make a slightly smaller reproducible example of the problem. The issue is that the memory usage for reverse-mode autodiff explodes.

from functools import partial
import numpy as np
import jax.numpy as jnp
from jax import random, jit, vmap
from jax import jacfwd, jacrev
from jax import lax
from jax import make_jaxpr
from jax import config  # `from jax.config import config` fails on recent JAX releases

config.update("jax_enable_x64", True)
config.update('jax_platform_name', 'gpu')

@jit
def lens_eq(z):
    zbar = jnp.conjugate(z)
    return z - 1. / zbar

@jit
def lens_eq_det_jac(z):
    zbar = jnp.conjugate(z)
    # det J = |dw/dz|^2 - |dw/dzbar|^2 = 1 - 1/|z|^4 for w = z - 1/zbar
    return 1.0 - 1.0 / jnp.abs(zbar**2) ** 2


@jit
def images_point_source(w):
    w_abs_sq = jnp.abs(w) ** 2
    w_bar = jnp.c…
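As a sanity check on the Jacobian determinant, it can be cross-checked with forward-mode autodiff by treating z = x + iy as a point in R². This is a minimal sketch that is not from the thread; `lens_eq_real`, `det_jac_autodiff` and `det_jac_analytic` are hypothetical helpers. For w = z - 1/z̄ the analytic determinant is 1 - 1/|z|^4. Since this map has only two real inputs, `jacfwd` builds the full Jacobian from two JVPs without storing intermediates for a backward pass.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # match the post's float64 setting


def lens_eq_real(xy):
    # Hypothetical helper: the lens map w = z - 1/conj(z) written as R^2 -> R^2,
    # with z = x + i*y packed into a length-2 real array.
    z = xy[0] + 1j * xy[1]
    w = z - 1.0 / jnp.conjugate(z)
    return jnp.stack([w.real, w.imag])


def det_jac_autodiff(xy):
    # Forward mode: one JVP per input dimension (two here).
    return jnp.linalg.det(jax.jacfwd(lens_eq_real)(xy))


def det_jac_analytic(xy):
    # dw/dz = 1 and dw/dzbar = 1/zbar**2, so
    # det J = |dw/dz|^2 - |dw/dzbar|^2 = 1 - 1/|z|^4.
    r2 = xy[0] ** 2 + xy[1] ** 2
    return 1.0 - 1.0 / r2 ** 2


xy = jnp.array([0.7, -1.3])
print(det_jac_autodiff(xy), det_jac_analytic(xy))
```

If the two printed values disagree, the closed-form determinant is the first thing to re-derive.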

Replies: 1 comment 13 replies

@YouJiacheng
@YouJiacheng
@fbartolic
@YouJiacheng
@fbartolic

Answer selected by fbartolic
Category
Q&A
Labels
None yet
3 participants