1717# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
1818# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1919
20+ import contextlib
2021import dataclasses as dc
2122import sys
2223from typing import Any , get_args , get_origin , TypeVar , Union
2324
25+ import jax
2426import jax .numpy as jnp
2527import jax .random as jr
2628import numpy as np
@@ -554,6 +556,16 @@ def test_arraylike(typecheck, getkey):
554556 floatlike2 = Float [ArrayLike , "" ]
555557 floatlike3 = Float32 [ArrayLike , "4" ]
556558
559+ def _literal (dtype , dimstr ):
560+ out = []
561+ with contextlib .suppress (Exception ):
562+ # JAX ==0.7.2
563+ out .append (dtype [jax ._src .literals .LiteralArray , dimstr ])
564+ with contextlib .suppress (Exception ):
565+ # JAX > 0.7.2
566+ out .append (dtype [jax ._src .literals .TypedNdArray , dimstr ])
567+ return out
568+
557569 assert get_origin (floatlike1 ) is Union
558570 assert get_origin (floatlike2 ) is Union
559571 assert get_origin (floatlike3 ) is Union
@@ -564,6 +576,7 @@ def test_arraylike(typecheck, getkey):
564576 Float32 [np .number , "" ],
565577 float ,
566578 ]
579+ + _literal (Float32 , "" )
567580 )
568581 assert _to_set (get_args (floatlike2 )) == _to_set (
569582 [
@@ -572,12 +585,14 @@ def test_arraylike(typecheck, getkey):
572585 Float [np .number , "" ],
573586 float ,
574587 ]
588+ + _literal (Float , "" )
575589 )
576590 assert _to_set (get_args (floatlike3 )) == _to_set (
577591 [
578592 Float32 [Array , "4" ],
579593 Float32 [np .ndarray , "4" ],
580594 ]
595+ + _literal (Float32 , "4" )
581596 )
582597
583598 shaped1 = Shaped [ArrayLike , "" ]
@@ -595,12 +610,14 @@ def test_arraylike(typecheck, getkey):
595610 float ,
596611 complex ,
597612 ]
613+ + _literal (Shaped , "" )
598614 )
599615 assert _to_set (get_args (shaped2 )) == _to_set (
600616 [
601617 Shaped [Array , "4" ],
602618 Shaped [np .ndarray , "4" ],
603619 ]
620+ + _literal (Shaped , "4" )
604621 )
605622
606623
0 commit comments