@@ -171,7 +171,7 @@ class PjitInfo(NamedTuple):
171
171
172
172
173
173
def _python_pjit_helper (jit_info , * args , ** kwargs ):
174
- (args_flat , params , _ , out_tree , _ , arg_names ,
174
+ (args_flat , params , in_avals , _ , out_tree , _ , arg_names , _ ,
175
175
attrs_tracked ) = _infer_params (jit_info , args , kwargs )
176
176
177
177
for arg in args_flat :
@@ -197,7 +197,7 @@ def _python_pjit_helper(jit_info, *args, **kwargs):
197
197
if params ['jaxpr' ].consts :
198
198
raise TypeError (e .args [0 ]) from e
199
199
else :
200
- for arg , name , aval in zip (args_flat , arg_names , params [ 'jaxpr' ]. in_avals ):
200
+ for arg , name , aval in zip (args_flat , arg_names , in_avals ):
201
201
try :
202
202
xla .canonicalize_dtype (arg )
203
203
except xla .InvalidInputException as _ :
@@ -491,7 +491,7 @@ def lower(*args, **kwargs):
491
491
492
492
@api_boundary
493
493
def eval_shape (* args , ** kwargs ):
494
- _ , params , _ , out_tree , _ , _ , _ = _infer_params (jit_info , args , kwargs )
494
+ _ , params , _ , _ , out_tree , _ , _ , _ , _ = _infer_params (jit_info , args , kwargs )
495
495
out_s = [None if is_unspecified (s ) else s for s in params ['out_shardings' ]]
496
496
# TODO(yashkatariya): Add `Layout` to SDS.
497
497
out = [api .ShapeDtypeStruct (x .shape , x .dtype , x .named_shape , sharding = s )
@@ -503,16 +503,15 @@ def trace(*args, **kwargs) -> stages.Traced:
503
503
lowering_parameters = kwargs .pop (
504
504
'_experimental_lowering_parameters' , mlir .LoweringParameters ())
505
505
506
- (args_flat , params , in_tree , out_tree , donated_invars ,
507
- arg_names , _ ) = _infer_params (jit_info , args , kwargs )
506
+ (args_flat , params , in_avals , in_tree , out_tree , donated_invars ,
507
+ arg_names , num_consts , _ ) = _infer_params (jit_info , args , kwargs )
508
508
509
509
donate_argnums = tuple (i for i , d in enumerate (donated_invars ) if d )
510
- jaxpr = params ['jaxpr' ]
511
- args_info = stages .make_args_info (in_tree , jaxpr .in_avals , donate_argnums )
510
+ args_info = stages .make_args_info (in_tree , in_avals , donate_argnums )
512
511
lower_callable = partial (_resolve_and_lower , args_flat , ** params ,
513
512
lowering_parameters = lowering_parameters )
514
- return stages .Traced (jaxpr , args_info , params ["name" ], out_tree ,
515
- lower_callable , args_flat , arg_names )
513
+ return stages .Traced (params [ ' jaxpr' ] , args_info , params ["name" ], out_tree ,
514
+ lower_callable , args_flat , arg_names , num_consts )
516
515
517
516
wrapped = _cpp_pjit (jit_info )
518
517
wrapped .lower = lower
@@ -662,8 +661,9 @@ def _infer_params(jit_info, args, kwargs):
662
661
keep_unused = keep_unused ,
663
662
inline = inline ,
664
663
)
665
- return (consts + args_flat , params , in_tree , out_tree (),
666
- donated_invars , dbg .arg_names if dbg else None , attrs_tracked )
664
+ return (consts + args_flat , params , in_avals , in_tree , out_tree (),
665
+ donated_invars , dbg .arg_names if dbg else None , len (consts ),
666
+ attrs_tracked )
667
667
668
668
def _extract_implicit_args (
669
669
in_type : Sequence [tuple [core .AbstractValue , bool ]],
0 commit comments