@@ -1717,10 +1717,10 @@ def _dimension_size_lowering_rule(ctx, arg, *, dimension):
1717
1717
mlir .register_lowering (dimension_size_p , _dimension_size_lowering_rule )
1718
1718
1719
1719
1720
- def all_dim_vars (args_avals : Sequence [core .AbstractValue ]) -> Sequence [str ]:
1720
+ def all_dim_vars (args_avals : Sequence [core .ShapedArray ]) -> Sequence [str ]:
1721
1721
dim_vars : set [str ] = set ()
1722
1722
for a in args_avals :
1723
- for d in a .shape : # type: ignore[attribute-error,unused-ignore]
1723
+ for d in a .shape :
1724
1724
if is_symbolic_dim (d ):
1725
1725
dim_vars = dim_vars .union (d ._get_vars ())
1726
1726
return sorted (dim_vars )
@@ -1911,7 +1911,7 @@ def pretty_print_dimension_descriptor(
1911
1911
1912
1912
@util .cache ()
1913
1913
def solve_dim_vars (
1914
- args_avals : Sequence [core .AbstractValue ],
1914
+ args_avals : Sequence [core .ShapedArray ],
1915
1915
args_kwargs_tree : tree_util .PyTreeDef ,
1916
1916
) -> tuple [DimVarEnv , ShapeConstraints , Sequence [tuple [str , int , int ]]]:
1917
1917
"""Solves dimension variables in a called function's avals in terms of actual argument shapes.
@@ -1956,12 +1956,12 @@ def solve_dim_vars(
1956
1956
# tuples with argument name and its polymorphic shape ('args[0]', '(a, a + b'))
1957
1957
polymorphic_shape_specs : list [tuple [str , str ]] = []
1958
1958
for arg_idx , aval in enumerate (args_avals ):
1959
- if all (not is_symbolic_dim (d ) for d in aval .shape ): # type: ignore[attribute-error,unused-ignore]
1959
+ if all (not is_symbolic_dim (d ) for d in aval .shape ):
1960
1960
continue
1961
1961
polymorphic_shape_specs .append (
1962
1962
(pretty_print_dimension_descriptor (args_kwargs_tree , arg_idx , None ),
1963
- str (aval .shape ))) # type: ignore[attribute-error,unused-ignore]
1964
- for dim_idx , aval_d in enumerate (aval .shape ): # type: ignore[attribute-error,unused-ignore]
1963
+ str (aval .shape )))
1964
+ for dim_idx , aval_d in enumerate (aval .shape ):
1965
1965
if is_symbolic_dim (aval_d ):
1966
1966
synth_dim_var = pretty_print_dimension_descriptor (args_kwargs_tree ,
1967
1967
arg_idx , dim_idx )
@@ -1976,7 +1976,7 @@ def solve_dim_vars(
1976
1976
1977
1977
1978
1978
def compute_dim_vars_from_arg_shapes (
1979
- args_avals : Sequence [core .AbstractValue ],
1979
+ args_avals : Sequence [core .ShapedArray ],
1980
1980
* actual_args : jax .Array ,
1981
1981
args_kwargs_tree : tree_util .PyTreeDef ) -> Sequence [jax .Array ]:
1982
1982
"""Computes values of dimension variables to unify args_avals with actual arguments.
0 commit comments