I assume this is on a GPU backend, yes? On GPU, scan is often very inefficient because it requires serial computations, and thus can't take advantage of the parallelism inherent in GPU vectorized operations.

I can't think of any way to express this functionality in terms of convolution (there doesn't really seem to be any reduction involved). But I think you can express this in terms of broadcasted indices:

import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=(2,))
def overlap_merge(local_grid, global_grid, stride):
  out_shape = (
      (global_grid.shape[0] - 1) * stride[0] + local_grid.shape[0],
      (global_grid.shape[1] - 1) * stride[1] + local_grid.shape[1],
  )
  out_dtype = jnp.result_type(l…
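
Since the snippet above is cut off, here is a separate minimal sketch of the broadcasted-index idea. It is not a completion of `overlap_merge`. The helper name `scatter_add_patches` and its input layout are assumptions made for illustration: `patches` has shape `(G0, G1, L0, L1)`, patch `(a, b)` is placed with its top-left corner at `(a * stride[0], b * stride[1])`, and overlapping cells are summed. The output shape follows the same formula as `out_shape` above.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=(1,))
def scatter_add_patches(patches, stride):
  """Hypothetical helper: sum overlapping patches into one 2D array."""
  g0, g1, l0, l1 = patches.shape
  out_shape = ((g0 - 1) * stride[0] + l0, (g1 - 1) * stride[1] + l1)
  # Output row of every patch element, shape (G0, 1, L0, 1).
  rows = (stride[0] * jnp.arange(g0))[:, None, None, None] + jnp.arange(l0)[None, None, :, None]
  # Output column of every patch element, shape (1, G1, 1, L1).
  cols = (stride[1] * jnp.arange(g1))[None, :, None, None] + jnp.arange(l1)[None, None, None, :]
  # The two index arrays broadcast to (G0, G1, L0, L1). Unlike NumPy's
  # `x[idx] += y`, `.at[...].add` accumulates every duplicate index.
  out = jnp.zeros(out_shape, patches.dtype)
  return out.at[rows, cols].add(patches)


# Four 3x3 patches of ones on a 2x2 layout with stride 2.
patches = jnp.ones((2, 2, 3, 3))
merged = scatter_add_patches(patches, (2, 2))
# merged has shape (5, 5); cell (2, 2) is covered by all four patches.
```

The whole merge is a single scatter-add with no `scan`, so there is no serial loop over patches for the GPU to wait on.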

Answer selected by wztdream