@@ -283,15 +283,15 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...],
283283 return _argnums_partial (f , dyn_argnums , tuple (fixed_args )), dyn_args
284284
285285@lu .transformation2
286- def _argnums_partial (f , dyn_argnums , fixed_args , * dyn_args , ** kwargs ):
286+ def _argnums_partial (_fun , _dyn_argnums , _fixed_args , * dyn_args , ** kwargs ):
287287 sentinel = object ()
288- args = [sentinel ] * (len (fixed_args ) + len (dyn_args ))
289- for i , arg in zip (dyn_argnums , dyn_args ):
288+ args = [sentinel ] * (len (_fixed_args ) + len (dyn_args ))
289+ for i , arg in zip (_dyn_argnums , dyn_args ):
290290 args [i ] = arg
291- fixed_args_ = iter (fixed_args )
291+ fixed_args_ = iter (_fixed_args )
292292 args = [next (fixed_args_ ).val if x is sentinel else x for x in args ]
293293 assert next (fixed_args_ , sentinel ) is sentinel
294- return f (* args , ** kwargs )
294+ return _fun (* args , ** kwargs )
295295
296296def argnames_partial_except (f : lu .WrappedFun , static_argnames : tuple [str , ...],
297297 kwargs : dict [str , Any ]):
@@ -315,9 +315,9 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
315315 return _argnames_partial (f , WrapKwArgs (fixed_kwargs )), dyn_kwargs
316316
317317@lu .transformation2
318- def _argnames_partial (f , fixed_kwargs : WrapKwArgs , * args , ** dyn_kwargs ):
319- kwargs = dict ({k : v .val for k , v in fixed_kwargs .val .items ()}, ** dyn_kwargs )
320- return f (* args , ** kwargs )
318+ def _argnames_partial (_fun , _fixed_kwargs : WrapKwArgs , * args , ** dyn_kwargs ):
319+ kwargs = dict ({k : v .val for k , v in _fixed_kwargs .val .items ()}, ** dyn_kwargs )
320+ return _fun (* args , ** kwargs )
321321
322322
323323@lru_cache (maxsize = 4096 )
@@ -438,9 +438,9 @@ def flat_out_axes(
438438 return f , HashableFunction (out_axes , closure = (tuple (leaves ), treedef ))
439439
440440@lu .transformation_with_aux2
441- def _flat_out_axes (f , store , leaves , treedef , * args , ** kwargs ):
442- ans = f (* args , ** kwargs )
443- spec = tree_unflatten (treedef , leaves )
441+ def _flat_out_axes (_fun , _store , _leaves , _treedef , * args , ** kwargs ):
442+ ans = _fun (* args , ** kwargs )
443+ spec = tree_unflatten (_treedef , _leaves )
444444 try :
445445 spec_flat = tuple (broadcast_prefix (spec , ans , is_leaf = lambda x : x is None ))
446446 except ValueError :
@@ -451,7 +451,7 @@ def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs):
451451 "that the `out_axes` argument to `pmap` is a pytree prefix of the "
452452 "pmapped function's output." )
453453 raise ValueError (msg ) from None
454- store .store (spec_flat )
454+ _store .store (spec_flat )
455455 return ans
456456
457457def check_callable (fun ):
@@ -687,10 +687,10 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames,
687687 for path , l in generate_key_paths (x ) if l is not static )
688688
689689@lu .transformation_with_aux2
690- def result_paths (f , store , * args , ** kwargs ):
690+ def result_paths (_fun , _store , * args , ** kwargs ):
691691 "linear_util transform to get output pytree paths of pre-flattened function."
692- ans = f (* args , ** kwargs )
693- store .store ([keystr (path ) for path , _ in generate_key_paths (ans )])
692+ ans = _fun (* args , ** kwargs )
693+ _store .store ([keystr (path ) for path , _ in generate_key_paths (ans )])
694694 return ans
695695
696696def jaxpr_debug_info (jaxpr : core .Jaxpr , trace_debug : TracingDebugInfo | None ,
0 commit comments