Skip to content
Discussion options

You must be logged in to vote

Perhaps something like this will work?

import jax.numpy as jnp
from jax import lax
import numpy as np


def normpool(x, window_size=2, stride=None):
  """Pool rows of ``x`` by keeping the largest-norm row in each window.

  The rows of ``x`` are treated as vectors. A 1-D window slides over the
  row axis and, for every window position, the row with the largest
  Euclidean norm (taken over the last axis) is selected.

  Args:
    x: 2-D array whose rows are the vectors to pool. The reduction uses a
      one-dimensional window, so the per-row norms must form a 1-D array.
    window_size: Number of consecutive rows per window. Defaults to 2.
    stride: Step between window starts. Defaults to ``window_size``
      (non-overlapping windows).

  Returns:
    Array holding one selected row of ``x`` per window. No padding is
    applied, so trailing rows that do not fill a whole window are dropped.
  """
  if stride is None:
    stride = window_size

  norms = jnp.linalg.norm(x, axis=-1)
  row_ids = jnp.arange(x.shape[0])

  def pick_larger(a, b):
    # Each argument is a (norm, row index) pair; the index rides along with
    # its norm so the reduction yields an argmax. Ties keep the first pair.
    a_norm, a_idx = a
    b_norm, b_idx = b
    keep_a = a_norm >= b_norm
    return jnp.where(keep_a, a_norm, b_norm), jnp.where(keep_a, a_idx, b_idx)

  # Identity element for the reduction. The init values are given the same
  # dtypes as the operands so the call does not depend on how untyped Python
  # scalars are promoted (e.g. half-precision inputs or x64 mode).
  init = (np.array(-np.inf, dtype=norms.dtype),
          np.array(-1, dtype=row_ids.dtype))

  _, winners = lax.reduce_window(
      (norms, row_ids), init, pick_larger,
      window_dimensions=(window_size,), window_strides=(stride,),
      padding=((0, 0),))
  return x[winners]


# Demo: four 3-vectors pooled in non-overlapping pairs of rows.
# Named `vectors` rather than `input` so the builtin `input` is not shadowed.
vectors = jnp.array([
  [1.0, 0.0, 1.0],
  [2.0, 2.0, 0.0],
  [3.0, 0.0, 1.0],
  [0.0, 1.0, 1.0],
])
# Rows 0-1 -> [2, 2, 0] (norm sqrt(8) beats sqrt(2));
# rows 2-3 -> [3, 0, 1] (norm sqrt(10) beats sqrt(2)).
print(normpool(vectors))
# # In each window, pick the vector with the largest norm
# output…

Replies: 1 comment 9 replies

Comment options

You must be logged in to vote
9 replies
@mattjj
Comment options

@mariogeiger
Comment options

@hawkinsp
Comment options

@YouJiacheng
Comment options

@mariogeiger
Comment options

Answer selected by mariogeiger
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Category
Q&A
Labels
None yet
4 participants