Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions tests/hijax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,66 @@ def f(x):
self.assertAllClose(jax.vmap(f)(xs), xs**3)
self.assertEqual(jax.grad(f)(2.0), 12.0)

@config.numpy_dtype_promotion('standard')
def test_newstyle_hiprimitive_qarray(self):

@dataclass(frozen=True) # not NamedTuple, which is a pytree
class QArray:
qvalue: jax.Array
scale: jax.Array

@dataclass(frozen=True)
class QArrayTy(HiType):
shape: tuple[int, int]

def to_tangent_aval(self):
return ShapedArray(self.shape, jnp.dtype('float32'))

register_hitype(QArray, lambda q: QArrayTy(q.qvalue.shape))

def q(x):
return Q(jax.typeof(x))(x)

def dq(qx):
return DQ(jax.typeof(qx))(qx)

class Q(NewstyleHiPrimitive):
def __init__(self, unquantized_aval):
if unquantized_aval.dtype != jnp.dtype('float32'): raise TypeError
quantized_aval = QArrayTy(unquantized_aval.shape)
super().__init__((unquantized_aval,), quantized_aval)

def expand(self, x):
scale = jnp.max(jnp.abs(x)) / 127
qvalue = jnp.round(x / scale).astype(jnp.int8)
return QArray(qvalue, scale)

def vjp_fwd(self, x):
return self(x), None

def vjp_bwd(self, _, g):
return g,

class DQ(NewstyleHiPrimitive):
def __init__(self, quantized_aval):
unquantized_aval = ShapedArray(quantized_aval.shape, jnp.dtype('float32'))
super().__init__((quantized_aval,), unquantized_aval)

def expand(self, qx):
return qx.qvalue * qx.scale

def vjp_fwd(self, qx):
return self(qx), None

def vjp_bwd(self, _, g):
return g,

def f(x):
return jnp.sum(dq(q(x)))

x = jax.random.normal(jax.random.key(0), (3, 3), dtype='float32')
g = jax.grad(f)(x)


class BoxTest(jtu.JaxTestCase):

Expand Down
Loading