diff --git a/tests/hijax_test.py b/tests/hijax_test.py index b0cbcf4541b4..311c0f069fa2 100644 --- a/tests/hijax_test.py +++ b/tests/hijax_test.py @@ -535,6 +535,66 @@ def f(x): self.assertAllClose(jax.vmap(f)(xs), xs**3) self.assertEqual(jax.grad(f)(2.0), 12.0) + @config.numpy_dtype_promotion('standard') + def test_newstyle_hiprimitive_qarray(self): + + @dataclass(frozen=True) # not NamedTuple, which is a pytree + class QArray: + qvalue: jax.Array + scale: jax.Array + + @dataclass(frozen=True) + class QArrayTy(HiType): + shape: tuple[int, int] + + def to_tangent_aval(self): + return ShapedArray(self.shape, jnp.dtype('float32')) + + register_hitype(QArray, lambda q: QArrayTy(q.qvalue.shape)) + + def q(x): + return Q(jax.typeof(x))(x) + + def dq(qx): + return DQ(jax.typeof(qx))(qx) + + class Q(NewstyleHiPrimitive): + def __init__(self, unquantized_aval): + if unquantized_aval.dtype != jnp.dtype('float32'): raise TypeError + quantized_aval = QArrayTy(unquantized_aval.shape) + super().__init__((unquantized_aval,), quantized_aval) + + def expand(self, x): + scale = jnp.max(jnp.abs(x)) / 127 + qvalue = jnp.round(x / scale).astype(jnp.int8) + return QArray(qvalue, scale) + + def vjp_fwd(self, x): + return self(x), None + + def vjp_bwd(self, _, g): + return g, + + class DQ(NewstyleHiPrimitive): + def __init__(self, quantized_aval): + unquantized_aval = ShapedArray(quantized_aval.shape, jnp.dtype('float32')) + super().__init__((quantized_aval,), unquantized_aval) + + def expand(self, qx): + return qx.qvalue * qx.scale + + def vjp_fwd(self, qx): + return self(qx), None + + def vjp_bwd(self, _, g): + return g, + + def f(x): + return jnp.sum(dq(q(x))) + + x = jax.random.normal(jax.random.key(0), (3, 3), dtype='float32') + g = jax.grad(f)(x) + class BoxTest(jtu.JaxTestCase):