@@ -779,7 +779,7 @@ def _extract_implicit_args(
779779 args [d1 .val ] = d2
780780 assert core .same_referent (args [d1 .val ], d2 )
781781 assert all (x is not None for x in args )
782- return [x for x , (_ , e ) in zip (args , in_type ) if not e ] # pytype: disable=bad-return-type
782+ return [x for x , (_ , e ) in zip (args , in_type ) if not e ] # type: ignore
783783
784784def _flat_axes_specs (abstracted_axes , * args , ** kwargs
785785 ) -> list [pe .AbstractedAxesSpec ] | None :
@@ -1545,6 +1545,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
15451545 else :
15461546 resolved_in_shardings .append (arg_s )
15471547 else :
1548+ assert isinstance (arg_s , sharding .Sharding )
15481549 if dispatch .is_single_device_sharding (arg_s ):
15491550 resolved_in_shardings .append (UNSPECIFIED )
15501551 else :
@@ -1581,7 +1582,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
15811582 not isinstance (arg_s , PmapSharding ) and
15821583 not op_shardings .are_op_shardings_equal (
15831584 pjit_in_s ._to_xla_hlo_sharding (arg .ndim ), # type: ignore
1584- arg_s ._to_xla_hlo_sharding (arg .ndim ))):
1585+ arg_s ._to_xla_hlo_sharding (arg .ndim ))): # type: ignore
15851586 raise ValueError ('Sharding passed to pjit does not match the sharding '
15861587 'on the respective arg. '
15871588 f'Got pjit sharding: { pjit_in_s } ,\n '
0 commit comments