
I think you can use `jax.linear_transpose` here, and then you never have to work out by hand how to transpose a linear function!

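```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

x = jax.random.uniform(jax.random.PRNGKey(0), (1, 3, 512, 512))  # NCHW input

# Four 2x2 filters: one averaging filter and three sign-flipped difference filters.
kernel = np.ones((4, 1, 2, 2))
kernel[1, 0, 0, 1] = -1
kernel[1, 0, 1, 1] = -1

kernel[2, 0, 1, 0] = -1
kernel[2, 0, 1, 1] = -1

kernel[3, 0, 1, 0] = -1
kernel[3, 0, 0, 1] = -1
kernel *= 0.5  # each filter now has unit norm

# One copy of the 4 filters per input channel -> shape (12, 1, 2, 2), OIHW layout.
kernel = np.concatenate([kernel] * 3, 0)

kerner = jnp.asarray(kernel, dtype=jnp.float32)


def fwd(x):
    dn = lax.conv_dimension_numbers(
        x.shape,     # only ndim matters, not shape
        kerner.shape,  # only ndim matters, not shape
        ('NCHW', 'OIHW', 'NCHW'),  # layouts implied by the shapes above
    )
    # The original snippet was cut off here; the rest is an assumed completion.
    # feature_group_count=3 follows from the shapes (3 input channels, I=1);
    # the 2x2 stride (non-overlapping blocks) and 'VALID' padding are assumptions.
    return lax.conv_general_dilated(
        x, kerner,
        window_strides=(2, 2),
        padding='VALID',
        dimension_numbers=dn,
        feature_group_count=3,
    )


# jax.linear_transpose uses x only for its shape/dtype and returns a tuple.
fwd_T = jax.linear_transpose(fwd, x)
(x_rec,) = fwd_T(fwd(x))
# With orthonormal filters on non-overlapping blocks, the transpose is also the
# inverse, so x_rec should match x up to float32 rounding.
print(jnp.max(jnp.abs(x_rec - x)))
```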
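To see what `jax.linear_transpose` actually computes, here is a minimal sketch on a small dense linear map, where the result can be checked against `A.T` directly. The matrix `A`, the function `f`, and the input `y` are illustrative values only, not from the original thread.

```python
import jax
import jax.numpy as jnp

# A small linear map f(x) = A @ x with A of shape (3, 2) (illustrative values).
A = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])


def f(x):
    return A @ x


# The example primal is only used for its shape and dtype, not its values.
x_example = jnp.zeros(2)
f_T = jax.linear_transpose(f, x_example)

y = jnp.array([1.0, -1.0, 2.0])
(x_bar,) = f_T(y)  # one cotangent per primal argument, returned as a tuple

# The transpose of x -> A @ x is y -> A.T @ y.
assert jnp.allclose(x_bar, A.T @ y)
print(x_bar)
```

The same call works unchanged for a convolution like `fwd` above: JAX derives the transpose from the linear primitives, so there is no need to write the matching transposed convolution by hand.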

Replies: 2 comments, 4 replies (reply authors: @jejjohnson, @YouJiacheng)

Answer selected by jejjohnson
Category: Q&A
3 participants