Placement of JIT and performance relative to Pytorch #6769
Hello everyone, I hope you are all well. I have just started using JAX and have written a simple multi-layer perceptron following the structure provided in the documentation tutorial: https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html

Whilst I indeed found that jitting a gradient step of my neural network results in much faster performance, I found that the same gradient step in PyTorch runs significantly faster (approx. 3.6 times faster on GPU / 2.7 times faster on CPU). Is this to do with the way I am using `jit`? I have attached the code below to recreate the results. The experiment consists of performing 1000 gradient steps for an MLP implemented in JAX and another implemented in PyTorch. My code was run from a Google Colab session.

```python
import torch
import numpy as np
import jax.numpy as jnp
import jax.random as random
from jax import jit, grad
from functools import partial
import math
import torch.nn as nn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# MLP IN JAX
class jax_MLP():
    def __init__(self, H, key=random.PRNGKey(123)):
        keys = random.split(key, num=3)
        scale = 1 / math.sqrt(H)
        self.a = random.uniform(keys[0], minval=-1, maxval=1, shape=(1, H)) * scale
        self.b = random.uniform(keys[1], minval=-1, maxval=1, shape=(H,)) * scale
        self.w = random.uniform(keys[2], minval=-1, maxval=1, shape=(H, 1)) * scale

    def get_params(self):
        return [self.a, self.b, self.w]

    def forward(self, x, params):
        """
        params = [a, b, w]
        """
        a, b, w = params
        x1 = jnp.dot(x, a) + b   # first linear map with bias
        x2 = jnp.maximum(0, x1)  # ReLU
        x3 = jnp.dot(x2, w)      # second linear map, no bias
        return x3

    # forward pass and calculate L2 loss
    def forward_pass(self, x, y, params):
        preds = self.forward(x, params)
        return jnp.mean((preds - y) ** 2)

    @partial(jit, static_argnums=(0,))
    def forward_backward(self, x, y, params):
        loss = self.forward_pass(x, y, params)
        grads = grad(self.forward_pass, argnums=[2])(x, y, params)
        return loss, grads

    def update(self, grads, lr):
        da, db, dw = grads[0]
        self.a -= lr * da
        self.b -= lr * db
        self.w -= lr * dw
# MLP IN TORCH
class torch_MLP(nn.Module):
    """
    Multi-layer perceptron
    """
    def __init__(self, H=100, C=1):
        super(torch_MLP, self).__init__()
        self.H = H
        self.linear1 = nn.Linear(1, H, bias=True)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(H, C, bias=False)

    def forward(self, x):
        y = self.linear1(x)
        y = self.relu(y)
        y = self.linear2(y)
        return y
tmlp = torch_MLP(H=1000).to(device)
jmlp = jax_MLP(1000)
#Learning Rate
LR = 1e-4
#PYTORCH params
loss_fn = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.SGD(tmlp.parameters(),lr=LR)
# Define a single step of an MLP for both Jax and Torch
def torch_step(x, y):
    pred = tmlp(x)
    loss = loss_fn(pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

def jax_step(x, y):
    params = jmlp.get_params()
    loss, grads = jmlp.forward_backward(x, y, params)
    jmlp.update(grads, LR)

# Functions to run 1000 epochs of gradient steps on random data
def epochs_torch():
    for i in range(1000):
        x = np.random.randn(100, 1)
        y = np.random.randn(100, 1)
        x = torch.Tensor(x).to(device)
        y = torch.Tensor(y).to(device)
        torch_step(x, y)

def epochs_jax():
    for i in range(1000):
        x = np.random.randn(100, 1)
        y = np.random.randn(100, 1)
        x = jnp.array(x)
        y = jnp.array(y)
        jax_step(x, y)
%timeit epochs_torch()
%timeit epochs_jax()
```

The results for running on CPU and GPU with a Google Colab session are as follows: on CPU, the PyTorch MLP is approximately 2.7 times faster than the JAX MLP, and on GPU it is approximately 3.6 times faster. Many thanks for any help / pointers.
Replies: 1 comment 3 replies
I think the issue you have is the use of `static_argnums` with `forward_backward`. This re-compiles the function whenever the static input changes, which I think will change every time you update your instance of `jax_MLP`.

What I would recommend is to follow the philosophy used by haiku and similar packages, where the parameters are passed around explicitly in the optimization loop instead of updating a model object. This follows the JAX philosophy of keeping functions functional.

Explicitly, what I would recommend is changing `update` into a pure function of the parameters, un-jitting `forward_backward`, wrapping the whole gradient step in a function which you can then jit, and timing a loop over that jitted step (a sketch of this is given below). On my colab instance, this is now slightly faster than the pytorch implementation.
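Here is a minimal sketch of what that functional refactor might look like. The reply's original code blocks did not survive, so the free functions `forward`, `loss_fn`, and `step` below are my reconstruction under that philosophy, not the replier's exact code:

```python
import jax.numpy as jnp
from jax import jit, value_and_grad

LR = 1e-4  # same learning rate as in the original post


def forward(params, x):
    # The same two-layer MLP as jax_MLP.forward, written as a pure function of params.
    a, b, w = params
    x1 = jnp.dot(x, a) + b   # first linear map with bias
    x2 = jnp.maximum(0, x1)  # ReLU
    return jnp.dot(x2, w)    # second linear map, no bias


def loss_fn(params, x, y):
    # L2 loss, pure in params so it can be differentiated directly.
    return jnp.mean((forward(params, x) - y) ** 2)


@jit
def step(params, x, y):
    # One SGD step that returns the loss and *new* parameters rather than
    # mutating a model object, so nothing passed to jit needs to be static.
    loss, grads = value_and_grad(loss_fn)(params, x, y)
    new_params = [p - LR * g for p, g in zip(params, grads)]
    return loss, new_params
```

The initial parameters can still come from `jax_MLP(1000).get_params()`; only the gradient step itself is made pure and jitted.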
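A timing loop over that jitted `step` (together with `jax_MLP` from the original post) might then look like the following; the call to `block_until_ready()` is my assumption about what the reply added, since JAX dispatches work asynchronously and the loop would otherwise return before the last step has actually finished:

```python
import numpy as np

def epochs_jax_functional():
    params = jax_MLP(1000).get_params()  # initial parameters, as in the original post
    for _ in range(1000):
        x = jnp.array(np.random.randn(100, 1))
        y = jnp.array(np.random.randn(100, 1))
        loss, params = step(params, x, y)
    # Assumed addition: block on the result so %timeit measures the computation
    # itself rather than just the asynchronous dispatch of the final step.
    loss.block_until_ready()

%timeit epochs_jax_functional()
```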