How to vmap
(vectorize/speed up?) a function originally for an object and now for a list of objects?
#11690
-
Dear all: import time
from jax import jacfwd
class point():
    """A 1-D point that stores a single coordinate."""

    def __init__(self, x):
        # x is kept exactly as passed in; no conversion or validation is done.
        self.x = x
#A line with 2 points
class line():
    """A line described by its two end points."""

    def __init__(self, pt1, pt2):
        # inode holds the first end point, jnode the second; both are stored
        # as given.
        self.inode = pt1
        self.jnode = pt2
#A function calculates the length between two 1-D points
def L(x1, x2):
    """Return the signed difference ``x1 - x2`` of two 1-D coordinates."""
    return x1 - x2
#A function for a line object
def Line_sensitivity(a_line):
i_node = a_line.inode #first node of this line
j_node = a_line.jnode #second node
xi = i_node.x
xj = j_node.x
#calculate the sensitivity of the length of the line
return jacfwd(L)(xi,xj)
Now let's generate 5000 lines and record the time needed:
#Generate 5000 lines
pts_coord = np.random.rand(5000,2) #5000 lines
line_list = [] #list for lines
for i in range(pts_coord.shape[0]):
pt1 = point(pts_coord[i,0])
pt2 = point(pts_coord[i,1])
line_list.append(line(pt1,pt2))
#Calculate the sensitivity and speed test
#Standard for loops
time_now = time.time()
for i in range(pts_coord.shape[0]):
Line_sensitivity(line_list[i])
print('For loop time = {:.2f}s'.format(time.time()-time_now))
#tree_map
time_now = time.time()
jax.tree_util.tree_map(Line_sensitivity,line_list)
print('Tree_map time = {:.2f}s'.format(time.time()-time_now)) The outputs:
As can be seen in the outputs, GY |
Beta Was this translation helpful? Give feedback.
Replies: 2 comments 4 replies
-
I see 3 ways to speedup your code: jit, vmap, both (1) jit: You can simply f = jax.jit(lambda line_list: jax.tree_util.tree_map(Line_sensitivity,line_list))
f(line_list) # the first call will be slow
f(line_list) # all the others will be fast It will be fast but the potential problem with this approach is that the compilation time will grow with the length of the list. (2) vmap: pts_coord = np.random.rand(5000,2) #5000 lines
lines = line(point(pts_coord[:, 0]), point(pts_coord[:, 1]))
jax.vmap(Line_sensitivity)(lines) (3) both f = jax.jit(jax.vmap(Line_sensitivity)) (1) will definitely speedup your code. I am not sure that (3) is faster than (1) but I would favor the option (3) to option (1) |
Beta Was this translation helpful? Give feedback.
-
Some updates!
import time
from jax import jacfwd
from jax.tree_util import register_pytree_node_class
#A line with 2 points
class new_line():
    """A line stored directly as its two end coordinates."""

    def __init__(self, xi, xj):
        self.xi = xi
        self.xj = xj


#Register line
@register_pytree_node_class
class register_new_line(new_line):
    """Pytree-registered variant of ``new_line``.

    The base class must be ``new_line``: its constructor is the one that sets
    the ``xi``/``xj`` attributes read by ``__repr__`` and ``tree_flatten``.
    """

    def __repr__(self):
        return "register_line(x_i={},x_j={})".format(self.xi, self.xj)

    def tree_flatten(self):
        """Return ``(children, aux_data)``: both coordinates, no static data."""
        children = (self.xi, self.xj)
        aux_data = None
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from the children produced by ``tree_flatten``."""
        return cls(*children)
#A function for a line object
def New_Line_sensitivity(a_line):
    """Differentiate ``L`` with respect to both end coordinates of ``a_line``.

    ``a_line`` must expose ``xi`` and ``xj`` attributes. The return value is
    whatever ``jacfwd(L, argnums=(0, 1))`` produces when evaluated at
    ``(xi, xj)``.
    """
    xi = a_line.xi
    xj = a_line.xj
    #calculate the sensitivity of the length of the line
    return jacfwd(L, argnums=(0, 1))(xi, xj)
pts_coord = np.random.rand(5000,2) #5000 lines
lines = register_new_line(pts_coord[:, 0], pts_coord[:, 1])
#vmap time
time_now = time.time()
out = jax.vmap(New_Line_sensitivity)(lines)
print('Vmap time = {:.2f}s'.format(time.time()-time_now))
The output now is
However, if I keep the old classes and register them as follows:
#register this point
@register_pytree_node_class
class register_point(point):
    """Pytree-registered variant of ``point``."""

    def __repr__(self):
        return "register_point(x={})".format(self.x)

    def tree_flatten(self):
        """Return ``(children, aux_data)``: the coordinate as the only child."""
        # The trailing comma is required: ``(self.x)`` is just ``self.x``, not
        # a 1-tuple, and ``tree_unflatten`` unpacks children with ``*children``.
        children = (self.x,)
        aux_data = None
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from the children produced by ``tree_flatten``."""
        return cls(*children)
@register_pytree_node_class
class register_line(line):
def __repr__(self):
return "register_line(x_i={},x_j={})".format(self.inode.x,self.jnode.x)
def tree_flatten(self):
children = (self.inode.x,self.jnode.x)
aux_data = None
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)
Doing (2) gives me an error:
AttributeError Traceback (most recent call last)
[<ipython-input-30-040de9cd0525>](https://localhost:8080/#) in tree_flatten(self)
34
35 def tree_flatten(self):
---> 36 children = (self.inode.x,self.jnode.x)
37 aux_data = None
38 return (children, aux_data)
AttributeError: 'object' object has no attribute 'x' I know it's because my Thanks again :). |
Beta Was this translation helpful? Give feedback.
I see 3 ways to speed up your code: jit, vmap, both
(1) jit: You can simply
jax.jit
your function. It will be fast, but the potential problem with this approach is that the compilation time will grow with the length of the list.
(2) vmap:
jax.vmap
only works along an axis of an ndarray
. So you will have to change the format of your input.(3) both
…