Large Performance Gap Between JAX and PyTorch on Tree Traversal #16772
-
On the JAX side, I suspect what you're running into is long compile times. Under JIT, JAX will unroll all Python loops into a linear program and pass this flat set of instructions to XLA, which can result in large programs with long compile times. I'd suggest one of the following:
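The loop-unrolling effect described in the first paragraph can be seen by inspecting the traced program. The sketch below is illustrative only: `python_loop`, `rolled_loop` and `N_STEPS` are hypothetical names, and the loop body is an arbitrary stand-in. It counts the top-level equations in each function's jaxpr. The Python loop emits new operations for every iteration, while `jax.lax.fori_loop` keeps the body as a single traced subcomputation.

```python
import jax
import jax.numpy as jnp

N_STEPS = 100  # hypothetical loop length, for illustration only

def python_loop(x):
    # Under tracing, this Python loop is unrolled: every iteration emits new ops.
    for _ in range(N_STEPS):
        x = jnp.sin(x) + 1.0
    return x

def rolled_loop(x):
    # lax.fori_loop traces the body once and keeps it as a single loop primitive.
    return jax.lax.fori_loop(0, N_STEPS, lambda i, x: jnp.sin(x) + 1.0, x)

x = jnp.ones(3)
# Number of top-level operations in each traced program.
n_unrolled = len(jax.make_jaxpr(python_loop)(x).jaxpr.eqns)
n_rolled = len(jax.make_jaxpr(rolled_loop)(x).jaxpr.eqns)
print(n_unrolled, n_rolled)
```

The unrolled program grows linearly with `N_STEPS`, which is what drives up XLA compile time. The rolled version stays constant-size no matter how many iterations it runs.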
Finally, I should note that your JAX benchmarks may not be accurate because of JAX's asynchronous dispatch; see the FAQ entry "Benchmarking JAX code" for some tips on getting accurate benchmarks of JAX functions. In particular, you probably want to wrap your benchmark code in …
Good luck!
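A common benchmarking pattern, and one the JAX FAQ also recommends, is to call `block_until_ready()` on a jitted function's output before stopping the timer, and to time the first (compiling) call separately. The sketch below assumes a toy workload: `f` and `benchmark` are hypothetical names, not part of the code in this thread.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Hypothetical stand-in workload.
    return jnp.tanh(x @ x.T).sum()

def benchmark(fn, *args, n_iter=10):
    # The first call triggers tracing and compilation, so time it separately.
    start = time.perf_counter()
    fn(*args).block_until_ready()
    compile_and_first_run = time.perf_counter() - start

    times = []
    for _ in range(n_iter):
        start = time.perf_counter()
        # Without block_until_ready(), this would mostly measure async dispatch.
        fn(*args).block_until_ready()
        times.append(time.perf_counter() - start)
    return compile_and_first_run, sum(times) / len(times)

x = jnp.ones((256, 256))
first, steady = benchmark(f, x)
print(f"first call: {first * 1e3:.2f} ms, steady state: {steady * 1e3:.2f} ms")
```

Reporting the two numbers separately makes it clear whether a slowdown comes from compilation or from the steady-state computation.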
-
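@jamespinkerton just like what @jakevdp mentioned, you can use `jax.lax.fori_loop` instead of a Python loop:

```python
from functools import partial

import jax
import jax.numpy as jnp

# Method of the model class. `self` (argument 0) is marked static, so it must be
# hashable, and the arrays stored on it are baked into the compiled program.
@partial(jax.jit, static_argnums=(0,))
def __call__(self, x: jax.Array) -> jax.Array:
    # One starting node index per (sample, tree) pair, flattened.
    indexes = jnp.broadcast_to(
        self.nodes_offset, (x.shape[0], self.num_trees)
    ).reshape(-1)

    def fori_loop_body(i, indexes):
        # Gather the split parameters of the node each (sample, tree) is at.
        thresholds = self.select(self.thresholds, indexes)
        lefts = self.select(self.lefts, indexes)
        rights = self.select(self.rights, indexes)
        feature_nodes = self.select(self.features, indexes)
        feature_values = jnp.take_along_axis(x, feature_nodes, axis=1)
        # Go left if the feature value is <= the threshold, otherwise go right.
        indexes = jnp.where(feature_values <= thresholds, lefts, rights).astype(
            jnp.int32
        )
        # Add each tree's offset into the flat node arrays.
        indexes = indexes + self.nodes_offset
        indexes = indexes.reshape(-1)
        return indexes

    # Traced once as a loop instead of being unrolled max_tree_depth times.
    indexes = jax.lax.fori_loop(0, self.max_tree_depth, fori_loop_body, indexes)
    return self.select(self.values, indexes).sum(1)
```

As a result, the average was 7.899 for torch and 7.568 for JAX, and JAX spends about 5 ms per loop (except the first one) on my M1 Pro MacBook.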
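For anyone who wants to run the fori_loop approach without the surrounding model class, here is a minimal self-contained sketch for a single tree. Everything in it is hypothetical: the depth-2 tree, the array names and `predict` are made up for illustration. It also assumes that leaf nodes point to themselves, so iterations beyond a leaf's depth are no-ops.

```python
import jax
import jax.numpy as jnp

# Hypothetical depth-2 tree stored as flat per-node arrays.
# Node 0: x[0] <= 0.5 ? node 1 : node 2
# Node 1: x[1] <= 0.0 ? node 3 : node 4
# Nodes 2, 3, 4 are leaves; they point to themselves so extra steps are no-ops.
lefts = jnp.array([1, 3, 2, 3, 4], dtype=jnp.int32)
rights = jnp.array([2, 4, 2, 3, 4], dtype=jnp.int32)
features = jnp.array([0, 1, 0, 0, 0], dtype=jnp.int32)
thresholds = jnp.array([0.5, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32)
values = jnp.array([0.0, 0.0, 10.0, 20.0, 30.0], dtype=jnp.float32)
MAX_DEPTH = 2

@jax.jit
def predict(x):
    # x: (batch, n_features); every sample starts at the root (node 0).
    idx = jnp.zeros(x.shape[0], dtype=jnp.int32)

    def body(_, idx):
        # Value of the split feature at each sample's current node.
        fval = jnp.take_along_axis(x, features[idx][:, None], axis=1)[:, 0]
        # Move to the left or right child depending on the threshold test.
        return jnp.where(fval <= thresholds[idx], lefts[idx], rights[idx])

    idx = jax.lax.fori_loop(0, MAX_DEPTH, body, idx)
    return values[idx]

x = jnp.array([[0.9, 0.0], [0.1, -1.0], [0.1, 1.0]], dtype=jnp.float32)
print(predict(x))  # [10. 20. 30.]
```

Because the loop runs a fixed `MAX_DEPTH` steps, the compiled program size does not depend on the tree depth.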
-
Hi. I'm writing code to do inference on decision tree models. The ideas come from https://github.com/microsoft/hummingbird.
My JAX implementation is roughly 4x slower than my PyTorch one, and I can't figure out whether that's my fault or whether JAX is inherently slower here. I'm running this on a CPU.
My guess is that I've written something very slow in JAX and that it's my fault. Any help would be much appreciated!