Skip to content

Commit dd2394e

Browse files
mattjjpatrick-kidger
authored andcommitted
add uint2/int2
following JAX's uint2/int2 data types: https://github.com/jax-ml/jax/blob/main/jax/_src/dtypes.py#L182
1 parent 682787f commit dd2394e

File tree

6 files changed

+20
-4
lines changed

6 files changed

+20
-4
lines changed

docs/api/array.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ The dtype should be any one of (all imported from `jaxtyping`):
6666
- Of particular precision: `Complex64`, `Complex128`
6767
- Any integer or unsigned intger: `Integer`
6868
- Any unsigned integer: `UInt`
69-
- Of particular precision: `UInt4`, `UInt8`, `UInt16`, `UInt32`, `UInt64`
69+
- Of particular precision: `UInt2`, `UInt4`, `UInt8`, `UInt16`, `UInt32`, `UInt64`
7070
- Any signed integer: `Int`
71-
- Of particular precision: `Int4`, `Int8`, `Int16`, `Int32`, `Int64`
71+
- Of particular precision: `Int2`, `Int4`, `Int8`, `Int16`, `Int32`, `Int64`
7272
- Any floating, integer, or unsigned integer: `Real`.
7373

7474
Unless you really want to force a particular precision, then for most applications you should probably allow any floating-point, any integer, etc. That is, use

jaxtyping/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
Float64 as Float64,
6767
Inexact as Inexact,
6868
Int as Int,
69+
Int2 as Int2,
6970
Int4 as Int4,
7071
Int8 as Int8,
7172
Int16 as Int16,
@@ -80,6 +81,7 @@
8081
ScalarLike as ScalarLike,
8182
Shaped as Shaped,
8283
UInt as UInt,
84+
UInt2 as UInt2,
8385
UInt4 as UInt4,
8486
UInt8 as UInt8,
8587
UInt16 as UInt16,
@@ -123,6 +125,7 @@
123125
Float64 as Float64,
124126
Inexact as Inexact,
125127
Int as Int,
128+
Int2 as Int2,
126129
Int4 as Int4,
127130
Int8 as Int8,
128131
Int16 as Int16,
@@ -134,6 +137,7 @@
134137
Real as Real,
135138
Shaped as Shaped,
136139
UInt as UInt,
140+
UInt2 as UInt2,
137141
UInt4 as UInt4,
138142
UInt8 as UInt8,
139143
UInt16 as UInt16,

jaxtyping/_array_types.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -735,11 +735,13 @@ def __init_subclass__(cls, **kwargs):
735735
_prng_key = "prng_key"
736736
_bool = "bool"
737737
_bool_ = "bool_"
738+
_uint2 = "uint2"
738739
_uint4 = "uint4"
739740
_uint8 = "uint8"
740741
_uint16 = "uint16"
741742
_uint32 = "uint32"
742743
_uint64 = "uint64"
744+
_int2 = "int2"
743745
_int4 = "int4"
744746
_int8 = "int8"
745747
_int16 = "int16"
@@ -772,11 +774,13 @@ class _Cls(AbstractDtype):
772774
return _Cls
773775

774776

777+
UInt2 = _make_dtype(_uint2, "UInt2")
775778
UInt4 = _make_dtype(_uint4, "UInt4")
776779
UInt8 = _make_dtype(_uint8, "UInt8")
777780
UInt16 = _make_dtype(_uint16, "UInt16")
778781
UInt32 = _make_dtype(_uint32, "UInt32")
779782
UInt64 = _make_dtype(_uint64, "UInt64")
783+
Int2 = _make_dtype(_int2, "Int2")
780784
Int4 = _make_dtype(_int4, "Int4")
781785
Int8 = _make_dtype(_int8, "Int8")
782786
Int16 = _make_dtype(_int16, "Int16")
@@ -795,8 +799,8 @@ class _Cls(AbstractDtype):
795799
Complex128 = _make_dtype(_complex128, "Complex128")
796800

797801
bools = [_bool, _bool_]
798-
uints = [_uint4, _uint8, _uint16, _uint32, _uint64]
799-
ints = [_int4, _int8, _int16, _int32, _int64]
802+
uints = [_uint2, _uint4, _uint8, _uint16, _uint32, _uint64]
803+
ints = [_int2, _int4, _int8, _int16, _int32, _int64]
800804
float8 = [
801805
_float8_e4m3b11fnuz,
802806
_float8_e4m3fn,

jaxtyping/_indirection.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
Annotated as Float64, # noqa: F401
3838
Annotated as Inexact, # noqa: F401
3939
Annotated as Int, # noqa: F401
40+
Annotated as Int2, # noqa: F401
4041
Annotated as Int4, # noqa: F401
4142
Annotated as Int8, # noqa: F401
4243
Annotated as Int16, # noqa: F401
@@ -48,6 +49,7 @@
4849
Annotated as Real, # noqa: F401
4950
Annotated as Shaped, # noqa: F401
5051
Annotated as UInt, # noqa: F401
52+
Annotated as UInt2, # noqa: F401
5153
Annotated as UInt4, # noqa: F401
5254
Annotated as UInt8, # noqa: F401
5355
Annotated as UInt16, # noqa: F401

test/test_all_importable.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@ def test_all_importable():
1818
Complex128, # noqa: F401
1919
Integer, # noqa: F401
2020
UInt, # noqa: F401
21+
UInt2, # noqa: F401
2122
UInt4, # noqa: F401
2223
UInt8, # noqa: F401
2324
UInt16, # noqa: F401
2425
UInt32, # noqa: F401
2526
UInt64, # noqa: F401
2627
Int, # noqa: F401
28+
Int2, # noqa: F401
2729
Int4, # noqa: F401
2830
Int8, # noqa: F401
2931
Int16, # noqa: F401

test/test_array.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def test_dtypes():
7878
Float64,
7979
Inexact,
8080
Int,
81+
Int2,
8182
Int4,
8283
Int8,
8384
Int16,
@@ -86,6 +87,7 @@ def test_dtypes():
8687
Num,
8788
Shaped,
8889
UInt,
90+
UInt2,
8991
UInt4,
9092
UInt8,
9193
UInt16,
@@ -156,8 +158,10 @@ def g(x: Shaped[Array, "a b"]) -> Shaped[Array, "a b"]:
156158

157159
g(jr.normal(getkey(), (3, 4)))
158160
g(jnp.array([[True, False]]))
161+
g(jnp.array([[1, 2], [3, 4]], dtype=jnp.int2))
159162
g(jnp.array([[1, 2], [3, 4]], dtype=jnp.int4))
160163
g(jnp.array([[1, 2], [3, 4]], dtype=jnp.int8))
164+
g(jnp.array([[1, 2], [3, 4]], dtype=jnp.uint2))
161165
g(jnp.array([[1, 2], [3, 4]], dtype=jnp.uint4))
162166
g(jnp.array([[1, 2], [3, 4]], dtype=jnp.uint16))
163167
g(jr.normal(getkey(), (3, 4), dtype=jnp.complex64))

0 commit comments

Comments
 (0)