@@ -171,7 +171,7 @@ class PjitInfo(NamedTuple):
 
 
 def _python_pjit_helper(jit_info, *args, **kwargs):
-  (args_flat, _, params, _, out_tree, _, arg_names,
+  (args_flat, params, _, out_tree, _, arg_names,
    attrs_tracked) = _infer_params(jit_info, args, kwargs)
 
   for arg in args_flat:
@@ -480,7 +480,7 @@ def lower(*args, **kwargs):
     lowering_parameters = kwargs.pop(
         '_experimental_lowering_parameters', mlir.LoweringParameters())
 
-    (args_flat, flat_global_in_avals, params, in_tree, out_tree,
+    (args_flat, params, in_tree, out_tree,
      donated_invars, arg_names, _) = _infer_params(jit_info, args, kwargs)
     try:
       lowering = _resolve_and_lower(
@@ -496,13 +496,14 @@ def lower(*args, **kwargs):
       raise ValueError(msg) from None
 
     donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
+    jaxpr = params["jaxpr"]
     return stages.Lowered.from_flat_info(
-        lowering, in_tree, flat_global_in_avals, donate_argnums,
-        out_tree, fun_name=params["name"], jaxpr=params["jaxpr"])
+        lowering, in_tree, jaxpr.in_avals, donate_argnums, out_tree,
+        fun_name=params["name"], jaxpr=jaxpr)
 
   @api_boundary
   def eval_shape(*args, **kwargs):
-    _, _, params, _, out_tree, _, _, _ = _infer_params(jit_info, args, kwargs)
+    _, params, _, out_tree, _, _, _ = _infer_params(jit_info, args, kwargs)
     out_s = [None if is_unspecified(s) else s for s in params['out_shardings']]
     # TODO(yashkatariya): Add `Layout` to SDS.
     out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape, sharding=s)
@@ -511,12 +512,19 @@ def eval_shape(*args, **kwargs):
 
   @api_boundary
   def specialize(*args, **kwargs) -> stages.Specialized:
-    _, _, params, in_tree, out_tree, donated_invars, _, _ = _infer_params(
+    lowering_parameters = kwargs.pop(
+        '_experimental_lowering_parameters', mlir.LoweringParameters())
+
+    args_flat, params, in_tree, out_tree, donated_invars, _, _ = _infer_params(
         jit_info, args, kwargs)
+
     donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
     jaxpr = params['jaxpr']
     args_info = stages.make_args_info(in_tree, jaxpr.in_avals, donate_argnums)
-    return stages.Specialized(jaxpr, args_info, out_tree)
+    lower_callable = partial(_resolve_and_lower, args_flat, **params,
+                             lowering_parameters=lowering_parameters)
+    return stages.Specialized(jaxpr, args_info, params["name"], out_tree,
+                              lower_callable)
 
   wrapped = _cpp_pjit(jit_info)
   wrapped.lower = lower
@@ -667,7 +675,7 @@ def _infer_params(jit_info, args, kwargs):
       keep_unused=keep_unused,
       inline=inline,
   )
-  return (consts + args_flat, in_type, params, in_tree, out_tree(),
+  return (consts + args_flat, params, in_tree, out_tree(),
           donated_invars, dbg.arg_names if dbg else None, attrs_tracked)
 
 def _extract_implicit_args(
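For context, a minimal sketch of how the reworked specialize stage might be exercised after this change. The function f and its argument are illustrative only; the diff shows wrapped.lower = lower but does not show specialize being attached to the wrapper, nor a Specialized.lower() method that invokes the stored lower_callable, so both are assumptions here.

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * 2.0

jitted = jax.jit(f)

# specialize() now threads the flat arguments and lowering parameters into a
# lower_callable (a partial over _resolve_and_lower), so the Specialized stage
# carries everything needed to lower itself later, instead of holding only
# (jaxpr, args_info, out_tree) as before this change.
specialized = jitted.specialize(jnp.ones((4,), jnp.float32))  # assumed entry point
print(specialized.jaxpr)  # the jaxpr captured in params['jaxpr'] at trace time
# lowered = specialized.lower()  # assumption: Specialized exposes the callable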