@@ -68,11 +68,13 @@ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
6868 else :
6969 return tuple (map (_ensure_str , x ))
7070
71- @lu .transformation_with_aux
72- def flatten_fun (in_tree , * args_flat ):
71+ @lu .transformation_with_aux2
72+ def flatten_fun (f , store , in_tree , * args_flat ):
7373 py_args , py_kwargs = tree_unflatten (in_tree , args_flat )
74- ans = yield py_args , py_kwargs
75- yield tree_flatten (ans )
74+ ans = f (* py_args , ** py_kwargs )
75+ ans , out_tree = tree_flatten (ans )
76+ store .store (out_tree )
77+ return ans
7678
7779def apply_flat_fun (fun , io_tree , * py_args ):
7880 in_tree_expected , out_tree = io_tree
@@ -82,11 +84,13 @@ def apply_flat_fun(fun, io_tree, *py_args):
8284 ans = fun (* args )
8385 return tree_unflatten (out_tree , ans )
8486
85- @lu .transformation_with_aux
86- def flatten_fun_nokwargs (in_tree , * args_flat ):
87+ @lu .transformation_with_aux2
88+ def flatten_fun_nokwargs (f , store , in_tree , * args_flat ):
8789 py_args = tree_unflatten (in_tree , args_flat )
88- ans = yield py_args , {}
89- yield tree_flatten (ans )
90+ ans = f (* py_args )
91+ ans , out_tree = tree_flatten (ans )
92+ store .store (out_tree )
93+ return ans
9094
9195def apply_flat_fun_nokwargs (fun , io_tree , py_args ):
9296 in_tree_expected , out_tree = io_tree
@@ -118,17 +122,18 @@ def flattened_fun_in_tree(
118122 else :
119123 return in_tree , lambda : out_tree_store .val , has_kwargs
120124
121- @lu .transformation_with_aux
122- def flatten_fun_nokwargs2 (in_tree , * args_flat ):
125+ @lu .transformation_with_aux2
126+ def flatten_fun_nokwargs2 (f , store , in_tree , * args_flat ):
123127 py_args = tree_unflatten (in_tree , args_flat )
124- pair = yield py_args , {}
128+ pair = f ( * py_args )
125129 if not isinstance (pair , (list , tuple )) or len (pair ) != 2 :
126130 raise TypeError ("expected function with aux output to return a two-element "
127131 f"tuple, but got type { type (pair )} with value { pair !r} " )
128132 ans , aux = pair
129133 ans_flat , ans_tree = tree_flatten (ans )
130134 aux_flat , aux_tree = tree_flatten (aux )
131- yield (ans_flat , aux_flat ), (ans_tree , aux_tree )
135+ store .store ((ans_tree , aux_tree ))
136+ return ans_flat , aux_flat
132137
133138class _HashableWithStrictTypeEquality :
134139 """Box object used when comparing static arguments as a jit key.
@@ -277,18 +282,16 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...],
277282
278283 return _argnums_partial (f , dyn_argnums , tuple (fixed_args )), dyn_args
279284
280- @lu .transformation
281- def _argnums_partial (dyn_argnums , fixed_args , * dyn_args , ** kwargs ):
285+ @lu .transformation2
286+ def _argnums_partial (f , dyn_argnums , fixed_args , * dyn_args , ** kwargs ):
282287 sentinel = object ()
283288 args = [sentinel ] * (len (fixed_args ) + len (dyn_args ))
284289 for i , arg in zip (dyn_argnums , dyn_args ):
285290 args [i ] = arg
286291 fixed_args_ = iter (fixed_args )
287292 args = [next (fixed_args_ ).val if x is sentinel else x for x in args ]
288293 assert next (fixed_args_ , sentinel ) is sentinel
289- ans = yield args , kwargs
290- yield ans
291-
294+ return f (* args , ** kwargs )
292295
293296def argnames_partial_except (f : lu .WrappedFun , static_argnames : tuple [str , ...],
294297 kwargs : dict [str , Any ]):
@@ -311,11 +314,10 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
311314
312315 return _argnames_partial (f , WrapKwArgs (fixed_kwargs )), dyn_kwargs
313316
314- @lu .transformation
315- def _argnames_partial (fixed_kwargs : WrapKwArgs , * args , ** dyn_kwargs ):
317+ @lu .transformation2
318+ def _argnames_partial (f , fixed_kwargs : WrapKwArgs , * args , ** dyn_kwargs ):
316319 kwargs = dict ({k : v .val for k , v in fixed_kwargs .val .items ()}, ** dyn_kwargs )
317- ans = yield args , kwargs
318- yield ans
320+ return f (* args , ** kwargs )
319321
320322
321323@lru_cache (maxsize = 4096 )
@@ -435,9 +437,9 @@ def flat_out_axes(
435437 f , out_axes = _flat_out_axes (f , tuple (leaves ), treedef )
436438 return f , HashableFunction (out_axes , closure = (tuple (leaves ), treedef ))
437439
438- @lu .transformation_with_aux
439- def _flat_out_axes (leaves , treedef , * args , ** kwargs ):
440- ans = yield args , kwargs
440+ @lu .transformation_with_aux2
441+ def _flat_out_axes (f , store , leaves , treedef , * args , ** kwargs ):
442+ ans = f ( * args , ** kwargs )
441443 spec = tree_unflatten (treedef , leaves )
442444 try :
443445 spec_flat = tuple (broadcast_prefix (spec , ans , is_leaf = lambda x : x is None ))
@@ -449,7 +451,8 @@ def _flat_out_axes(leaves, treedef, *args, **kwargs):
449451 "that the `out_axes` argument to `pmap` is a pytree prefix of the "
450452 "pmapped function's output." )
451453 raise ValueError (msg ) from None
452- yield ans , spec_flat
454+ store .store (spec_flat )
455+ return ans
453456
454457def check_callable (fun ):
455458 # In Python 3.10+, the only thing stopping us from supporting staticmethods
@@ -683,11 +686,12 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames,
683686 return tuple (f'{ name } { keystr (path )} ' for name , x in ba .arguments .items ()
684687 for path , l in generate_key_paths (x ) if l is not static )
685688
686- @lu .transformation_with_aux
687- def result_paths (* args , ** kwargs ):
689+ @lu .transformation_with_aux2
690+ def result_paths (f , store , * args , ** kwargs ):
688691 "linear_util transform to get output pytree paths of pre-flattened function."
689- ans = yield args , kwargs
690- yield ans , [keystr (path ) for path , _ in generate_key_paths (ans )]
692+ ans = f (* args , ** kwargs )
693+ store .store ([keystr (path ) for path , _ in generate_key_paths (ans )])
694+ return ans
691695
692696def jaxpr_debug_info (jaxpr : core .Jaxpr , trace_debug : TracingDebugInfo | None ,
693697 result_paths : tuple [str , ...] | None = None ,
0 commit comments