JAX-JIT 100x slower compared to Torch (same GPU)? #20287
Hello, here is the torch code:

```python
import time
import copy
import numpy as np
import torch

if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
    torch_device = 'cuda'
    torch.set_default_tensor_type(torch.cuda.DoubleTensor)
else:
    torch_device = 'cpu'
    torch.set_default_tensor_type(torch.DoubleTensor)
float_dtype = np.float64
print(f"TORCH DEVICE: {torch_device}")
# Tesla V100-SXM2-16GB
# TORCH DEVICE: cuda

def grab(var):
    return var.detach().cpu().numpy()

def dimless_action(
        phi: torch.Tensor,
        kap: float,
        lam: float,
        h: float):
    dims = [i + 1 for i in range(len(phi.shape) - 1)]
    action = (1 - 2 * lam) * phi**2 + lam * phi**4 + h * phi
    action += -2. * kap * phi * torch.roll(phi, 1, 1)
    action += -2. * kap * phi * torch.roll(phi, 1, 2)
    return torch.sum(action, dims)

def dimless_force(
        phi: torch.Tensor,
        kap: float,
        lam: float,
        h: float):
    dims = [i + 1 for i in range(len(phi.shape) - 1)]
    force = 2 * phi * (2 * lam * (1 - phi**2) - 1) - h
    force += 2. * kap * (torch.roll(phi, 1, 1) + torch.roll(phi, -1, 1))
    force += 2. * kap * (torch.roll(phi, 1, 2) + torch.roll(phi, -1, 2))
    return force

def hmc(
        phi: torch.Tensor,
        action_fn,
        force_fn,
        dt: float = 0.1,
        n_steps: int = 10):
    """Return batch of updated phi, acceptance rate."""
    dims = [i + 1 for i in range(len(phi.shape) - 1)]
    S = action_fn(phi)
    new_phi = copy.deepcopy(phi)
    momentum = torch.randn(phi.shape)  # debug 0.1*torch.ones_like(phi)
    hamiltonian = 0.5 * torch.sum(momentum**2, dims) + S
    momentum += 0.5 * dt * force_fn(new_phi)
    for _ in range(n_steps - 1):
        new_phi += dt * momentum
        momentum += dt * force_fn(new_phi)
    new_phi += dt * momentum
    momentum += 0.5 * dt * force_fn(new_phi)
    new_S = action_fn(new_phi)
    d_hamiltonian = 0.5 * torch.sum(momentum**2, dims) + new_S - hamiltonian
    accept = torch.rand(len(d_hamiltonian)) < torch.exp(-d_hamiltonian)
    # accept = torch.tensor([True]*len(d_hamiltonian))  # debug
    a = accept.view(-1, 1, 1).repeat(1, *phi.shape[1:])
    return a * new_phi + ~a * phi, accept

d = 2
L = 8
lattice_shape = [L for _ in range(d)]
kap = 0.27
lam = 0.02
h = 0.0
batch_size = 1000
dt = 0.1
n_steps = 10
phi_init = 0.1 * torch.randn(batch_size, *lattice_shape)
```

In a notebook cell I time one call to `hmc` and get the Torch reference timing.
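The exact timing cell is not shown in the thread; as an assumption, it presumably looked something like the following, with the lambdas binding `kap`, `lam`, `h` being my guess at how `action_fn` and `force_fn` were passed:

```python
# Hypothetical timing cell (not in the original post): time one batched HMC update
# on the GPU. Without an explicit torch.cuda.synchronize() this mostly measures
# asynchronous kernel dispatch, but the JAX cell below is timed the same way.
%timeit phi, acc = hmc(phi_init, \
                       action_fn=lambda p: dimless_action(p, kap, lam, h), \
                       force_fn=lambda p: dimless_force(p, kap, lam, h), \
                       dt=dt, n_steps=n_steps)
```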
Now, here is the JAX/JIT code (same type of GPU):

```python
import numpy as np
import jax.numpy as jnp
import jax
from functools import partial

@jax.jit
def get_Batch_S(phi, params):  # OK wrt torch
    """Get the dimensionless action.
    phi : tensor, Batch x L^d
    params : lambda, kappa, h
    """
    l, k, h = params
    dtot = jnp.ndim(phi)  # total number of dimensions, 1 + d
    d = range(1, dtot)
    action = (1 - 2 * l) * phi**2 + l * phi**4 + h * phi
    for mu in d:
        action += -2. * k * phi * jnp.roll(phi, 1, mu)
    return action.sum(axis=d)

@jax.jit
def get_Batch_force(phi, params):  # OK wrt torch
    """Get the dimensionless force = - gradient of the action.
    phi : tensor, Batch x L^d
    params : lambda, kappa, h
    """
    l, k, h = params
    dtot = jnp.ndim(phi)  # total number of dimensions, 1 + d
    d = range(1, dtot)
    force = 2 * phi * (2 * l * (1 - phi**2) - 1) - h
    for mu in d:
        force += 2. * k * (jnp.roll(phi, 1, mu) + jnp.roll(phi, -1, mu))
    return force

@partial(jax.jit, static_argnums=(2, 3, 4))
def hmc_one_step(key, phi, action_fn, force_fn, n_steps=10, dt=0.1):
    """Return batch of updated phi, acceptance rate."""
    dims = [i + 1 for i in range(len(phi.shape) - 1)]
    key, subkey0, subkey1 = jax.random.split(key, num=3)
    S = action_fn(phi)
    new_phi = jnp.copy(phi)
    momentum = jax.random.normal(subkey0, shape=phi.shape)  # debug 0.1 * jnp.ones_like(phi)
    hamiltonian = 0.5 * jnp.sum(momentum**2, axis=dims) + S

    # leapfrog
    def update(i, carry):
        new_phi, momentum = carry
        new_phi += dt * momentum
        momentum += dt * force_fn(new_phi)
        return new_phi, momentum

    momentum += 0.5 * dt * force_fn(new_phi)
    new_phi, momentum = jax.lax.fori_loop(0, n_steps - 1, update, (new_phi, momentum))
    new_phi += dt * momentum
    momentum += 0.5 * dt * force_fn(new_phi)
    new_S = action_fn(new_phi)
    d_hamiltonian = 0.5 * jnp.sum(momentum**2, axis=dims) + new_S - hamiltonian

    # Metropolis
    prob = jnp.exp(-d_hamiltonian)
    accept = jax.random.uniform(subkey1, shape=prob.shape) < prob
    # debug: accept = jnp.array([True] * len(prob))
    shape = tuple([phi.shape[0]] + [1 for i in range(len(phi.shape) - 1)])  # B, 1, ..., 1
    accept = accept.reshape(shape)
    res = jnp.where(accept, new_phi, phi)
    return res, accept

n_iter = 1000
batch_size = 10
L = 8
kap = 0.27
lam = 0.02
h = 0.0
key = jax.random.PRNGKey(40)
key, subkey = jax.random.split(key)
phi_init = 0.1 * jax.random.normal(subkey, (batch_size, L, L))
params = lam, kap, h
dt = 0.1
n_steps = 10
```

In a separate notebook cell, after a warm-up with a first call, I run

```python
%timeit res, _ = hmc_one_step(key, \
                              phi=phi_init, \
                              action_fn=lambda phi: get_Batch_S(phi, params), \
                              force_fn=lambda phi: get_Batch_force(phi, params), \
                              dt=dt, \
                              n_steps=n_steps)
```

and I get a timing roughly 100 times slower than the Torch cell above, hence the title.
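One caveat worth adding here (my note, not part of the original post): JAX dispatches GPU work asynchronously, so a bare `%timeit` can be misleading about where the time actually goes. A small variant of the same cell that forces the device work to finish before the timer stops:

```python
# Same timing cell, but blocking on the result so %timeit measures the full
# HMC step on the GPU rather than the asynchronous dispatch alone.
%timeit hmc_one_step(key, \
                     phi=phi_init, \
                     action_fn=lambda phi: get_Batch_S(phi, params), \
                     force_fn=lambda phi: get_Batch_force(phi, params), \
                     dt=dt, \
                     n_steps=n_steps)[0].block_until_ready()
```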
Now, in practice I call this 1000 times. There may be a CPU-GPU transfer induced somewhere, as the code gets the same answer at the end. Any idea is welcome. Thanks.
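To rule the CPU-GPU transfer suspicion in or out, a small check along these lines (my suggestion, using standard JAX calls; nothing here is from the original post) confirms which device the arrays live on and pins them explicitly:

```python
# Verify the GPU backend is in use and place the batch on it explicitly.
print(jax.devices())                                  # should list the GPU device
phi_dev = jax.device_put(phi_init, jax.devices()[0])  # explicit device placement
res, acc = hmc_one_step(key, phi=phi_dev,
                        action_fn=lambda phi: get_Batch_S(phi, params),
                        force_fn=lambda phi: get_Batch_force(phi, params),
                        dt=dt, n_steps=n_steps)
res.block_until_ready()  # all of the above runs on the device; no host copy yet
```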
Can someone help?
Hi, digging a little I have found the problem in my JAX code: it comes from the way I was introducing `action_fn` and `force_fn`. Because they are function arguments, I was forced to declare them as static arguments (`static_argnums`) of the jitted `hmc_one_step`. I have redesigned the code to avoid such static arguments, and now the JAX-JIT code is 3 times faster than the corresponding Torch code.
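The redesigned code is not shown in the reply, but the likely mechanism is this: the lambdas passed for `action_fn` and `force_fn` are re-created on every `%timeit` iteration, and static arguments are matched by hash/equality, so each call misses the compilation cache and re-traces and re-compiles the whole HMC step. A minimal sketch of one way to avoid static function arguments (my reconstruction, not the author's actual code; `hmc_one_step_v2` is a hypothetical name, and it reuses `get_Batch_S`/`get_Batch_force` from the question): pass `params` as an ordinary traced argument and call the action and force functions directly inside the jitted step.

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnums=(3,))  # only n_steps stays static
def hmc_one_step_v2(key, phi, params, n_steps=10, dt=0.1):
    """One HMC update; params is an ordinary traced tuple, so repeated calls
    with the same shapes hit the compilation cache instead of re-tracing."""
    dims = tuple(range(1, phi.ndim))
    key, subkey0, subkey1 = jax.random.split(key, num=3)
    S = get_Batch_S(phi, params)
    momentum = jax.random.normal(subkey0, shape=phi.shape)
    hamiltonian = 0.5 * jnp.sum(momentum**2, axis=dims) + S

    # leapfrog integrator, force evaluated directly (no static callable argument)
    def leapfrog(i, carry):
        new_phi, momentum = carry
        new_phi += dt * momentum
        momentum += dt * get_Batch_force(new_phi, params)
        return new_phi, momentum

    new_phi = phi
    momentum += 0.5 * dt * get_Batch_force(new_phi, params)
    new_phi, momentum = jax.lax.fori_loop(0, n_steps - 1, leapfrog, (new_phi, momentum))
    new_phi += dt * momentum
    momentum += 0.5 * dt * get_Batch_force(new_phi, params)
    new_S = get_Batch_S(new_phi, params)
    d_hamiltonian = 0.5 * jnp.sum(momentum**2, axis=dims) + new_S - hamiltonian

    # Metropolis accept/reject, broadcast over the lattice dimensions
    accept = jax.random.uniform(subkey1, shape=d_hamiltonian.shape) < jnp.exp(-d_hamiltonian)
    accept = accept.reshape((phi.shape[0],) + (1,) * (phi.ndim - 1))
    return jnp.where(accept, new_phi, phi), accept
```

Usage is then simply `res, acc = hmc_one_step_v2(key, phi_init, params)`, which compiles once and reuses the cached executable on every subsequent call. An equivalent fix that keeps the original signature is to define `action_fn` and `force_fn` once, outside the timed statement, so the same function objects are passed each time; either way, unwanted recompilation can be spotted with `jax.config.update("jax_log_compiles", True)`.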