@@ -460,7 +460,7 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
460
460
actual_lowering_platforms = (default_lowering_platform (),)
461
461
462
462
# TODO: move to `lower`
463
- symbolic_scope : tuple [shape_poly .SymbolicScope , tree_util .KeyPath ] | None = None
463
+ symbolic_scope : tuple [shape_poly .SymbolicScope , tree_util .KeyPath ] | None = None # type: ignore[invalid-annotation,unused-ignore]
464
464
for k_path , aval in tree_util .tree_flatten_with_path ((args_specs , kwargs_specs ))[0 ]:
465
465
# Static args may have no `shape` attribute.
466
466
if not hasattr (aval , "shape" ):
@@ -476,7 +476,7 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
476
476
other_descr = shape_poly .args_kwargs_path_to_str (k_path ))
477
477
478
478
if has_trace :
479
- traced = wrapped_fun_jax .trace (
479
+ traced = wrapped_fun_jax .trace ( # type: ignore
480
480
* args_specs , ** kwargs_specs ,
481
481
_experimental_lowering_parameters = mlir .LoweringParameters (
482
482
platforms = actual_lowering_platforms ,
@@ -547,7 +547,7 @@ def _export_lowered(
547
547
elif "shards" in lowering .compile_args : # for PmapComputation
548
548
out_avals_flat = lowering .compile_args ["shards" ].out_sharded_avals
549
549
else :
550
- out_avals_flat = lowered .compile_args ["out_avals" ]
550
+ out_avals_flat = lowered .compile_args ["out_avals" ] # type: ignore
551
551
552
552
# Log and then check the module.
553
553
if logging .vlog_is_on (3 ):
@@ -612,7 +612,7 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported:
612
612
in_shardings_hlo = in_shardings ,
613
613
out_shardings_hlo = out_shardings ,
614
614
nr_devices = nr_devices ,
615
- lowering_platforms = lowering ._platforms ,
615
+ lowering_platforms = lowering ._platforms , # type: ignore
616
616
ordered_effects = ordered_effects ,
617
617
unordered_effects = unordered_effects ,
618
618
disabled_safety_checks = tuple (disabled_checks ),
@@ -641,7 +641,7 @@ def _module_to_bytecode(module: ir.Module) -> bytes:
641
641
# and still have the payloads produced by `serialize_portable_artifact`
642
642
# compatible with potential consumers from the past.
643
643
target_version = hlo .get_minimum_version ()
644
- module_serialized = xla_client ._xla .mlir .serialize_portable_artifact (
644
+ module_serialized = xla_client ._xla .mlir .serialize_portable_artifact ( # type: ignore
645
645
mlir_str , target_version )
646
646
return module_serialized
647
647
@@ -688,8 +688,8 @@ def _wrap_main_func(
688
688
def is_token (typ , attrs ):
689
689
return (typ == mlir .token_type ()[0 ])
690
690
691
- orig_input_types = orig_main .type .inputs
692
- arg_attrs = list (ir .ArrayAttr (orig_main .arg_attrs ))
691
+ orig_input_types = orig_main .type .inputs # type: ignore
692
+ arg_attrs = list (ir .ArrayAttr (orig_main .arg_attrs )) # type: ignore
693
693
# The order of args: platform_index_arg, dim args, token args, array args.
694
694
nr_platform_index_args = 1 if has_platform_index_argument else 0
695
695
nr_dim_args = len (dim_vars )
@@ -711,8 +711,8 @@ def is_token(typ, attrs):
711
711
orig_input_types , [nr_platform_index_args , nr_dim_args , nr_token_args ])
712
712
713
713
# The order of results: tokens, array results
714
- orig_output_types = orig_main .type .results
715
- result_attrs = list (ir .ArrayAttr (orig_main .result_attrs ))
714
+ orig_output_types = orig_main .type .results # type: ignore
715
+ result_attrs = list (ir .ArrayAttr (orig_main .result_attrs )) # type: ignore
716
716
token_result_idxs = [i for i , (typ , attrs ) in enumerate (zip (orig_output_types ,
717
717
result_attrs ))
718
718
if is_token (typ , attrs )]
@@ -1138,6 +1138,8 @@ def _call_exported_abstract_eval(
1138
1138
assert len (in_avals ) == len (exported .in_avals ) # since the pytrees have the same structure
1139
1139
# Check that the expected shapes match the actual ones
1140
1140
for arg_idx , (exp_aval , actual_aval ) in enumerate (zip (exported .in_avals , in_avals )):
1141
+ exp_aval : core .ShapedArray = exp_aval # type: ignore
1142
+ actual_aval : core .ShapedArray = actual_aval # type: ignore
1141
1143
def pp_arg_dim (dim_idx : int | None ) -> str :
1142
1144
return shape_poly .pretty_print_dimension_descriptor (exported .in_tree ,
1143
1145
arg_idx , dim_idx )
@@ -1181,10 +1183,10 @@ def pp_arg_dim(dim_idx: int | None) -> str:
1181
1183
exported_dim_values = [synthetic_eval .evaluate (solution [var ])
1182
1184
for var in exported_dim_vars ]
1183
1185
out_avals = tuple (
1184
- core .ShapedArray (core .evaluate_shape (out_aval .shape , exported_dim_vars ,
1186
+ core .ShapedArray (core .evaluate_shape (out_aval .shape , exported_dim_vars , # type: ignore
1185
1187
* exported_dim_values ),
1186
- dtype = out_aval .dtype , weak_type = out_aval .weak_type ,
1187
- named_shape = out_aval .named_shape )
1188
+ dtype = out_aval .dtype , weak_type = out_aval .weak_type , # type: ignore
1189
+ named_shape = out_aval .named_shape ) # type: ignore
1188
1190
for out_aval in exported .out_avals )
1189
1191
return out_avals , set (exported .ordered_effects + exported .unordered_effects )
1190
1192
0 commit comments