Skip to content

Commit 55d0f5e

Browse files
yashk2810jax authors
authored andcommitted
Add lower to specialize making it a true Stage.
So now users can do: ``` specialized = jax.jit(f).specialize(*args) print(specialized.jaxpr, specialized.out_info) lowered = specialized.lower() compiled = lowered.compile() ``` PiperOrigin-RevId: 640737396
1 parent d117305 commit 55d0f5e

File tree

3 files changed

+44
-10
lines changed

3 files changed

+44
-10
lines changed

jax/_src/pjit.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ class PjitInfo(NamedTuple):
171171

172172

173173
def _python_pjit_helper(jit_info, *args, **kwargs):
174-
(args_flat, _, params, _, out_tree, _, arg_names,
174+
(args_flat, params, _, out_tree, _, arg_names,
175175
attrs_tracked) = _infer_params(jit_info, args, kwargs)
176176

177177
for arg in args_flat:
@@ -480,7 +480,7 @@ def lower(*args, **kwargs):
480480
lowering_parameters = kwargs.pop(
481481
'_experimental_lowering_parameters', mlir.LoweringParameters())
482482

483-
(args_flat, flat_global_in_avals, params, in_tree, out_tree,
483+
(args_flat, params, in_tree, out_tree,
484484
donated_invars, arg_names, _) = _infer_params(jit_info, args, kwargs)
485485
try:
486486
lowering = _resolve_and_lower(
@@ -496,13 +496,14 @@ def lower(*args, **kwargs):
496496
raise ValueError(msg) from None
497497

498498
donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
499+
jaxpr = params["jaxpr"]
499500
return stages.Lowered.from_flat_info(
500-
lowering, in_tree, flat_global_in_avals, donate_argnums,
501-
out_tree, fun_name=params["name"], jaxpr=params["jaxpr"])
501+
lowering, in_tree, jaxpr.in_avals, donate_argnums, out_tree,
502+
fun_name=params["name"], jaxpr=jaxpr)
502503

503504
@api_boundary
504505
def eval_shape(*args, **kwargs):
505-
_, _, params, _, out_tree, _, _, _ = _infer_params(jit_info, args, kwargs)
506+
_, params, _, out_tree, _, _, _ = _infer_params(jit_info, args, kwargs)
506507
out_s = [None if is_unspecified(s) else s for s in params['out_shardings']]
507508
# TODO(yashkatariya): Add `Layout` to SDS.
508509
out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape, sharding=s)
@@ -511,12 +512,19 @@ def eval_shape(*args, **kwargs):
511512

512513
@api_boundary
513514
def specialize(*args, **kwargs) -> stages.Specialized:
514-
_, _, params, in_tree, out_tree, donated_invars, _, _ = _infer_params(
515+
lowering_parameters = kwargs.pop(
516+
'_experimental_lowering_parameters', mlir.LoweringParameters())
517+
518+
args_flat, params, in_tree, out_tree, donated_invars, _, _ = _infer_params(
515519
jit_info, args, kwargs)
520+
516521
donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
517522
jaxpr = params['jaxpr']
518523
args_info = stages.make_args_info(in_tree, jaxpr.in_avals, donate_argnums)
519-
return stages.Specialized(jaxpr, args_info, out_tree)
524+
lower_callable = partial(_resolve_and_lower, args_flat, **params,
525+
lowering_parameters=lowering_parameters)
526+
return stages.Specialized(jaxpr, args_info, params["name"], out_tree,
527+
lower_callable)
520528

521529
wrapped = _cpp_pjit(jit_info)
522530
wrapped.lower = lower
@@ -667,7 +675,7 @@ def _infer_params(jit_info, args, kwargs):
667675
keep_unused=keep_unused,
668676
inline=inline,
669677
)
670-
return (consts + args_flat, in_type, params, in_tree, out_tree(),
678+
return (consts + args_flat, params, in_tree, out_tree(),
671679
donated_invars, dbg.arg_names if dbg else None, attrs_tracked)
672680

673681
def _extract_implicit_args(

jax/_src/stages.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,18 +426,25 @@ class CompiledCallParams(NamedTuple):
426426

427427

428428
class Specialized(Stage):
429-
__slots__ = ["jaxpr", "args_info", "_out_tree"]
429+
__slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable"]
430430

431-
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, out_tree):
431+
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
432+
lower_callable):
432433
self.jaxpr = jaxpr
433434
self.args_info = args_info
435+
self.fun_name = fun_name
434436
self._out_tree = out_tree
437+
self._lower_callable = lower_callable
435438

436439
@property
437440
def out_info(self):
438441
return self._out_tree.unflatten(
439442
[OutInfo(o.shape, o.dtype) for o in self.jaxpr.out_avals])
440443

444+
def lower(self):
445+
lowering = self._lower_callable()
446+
return Lowered(lowering, self.args_info, self._out_tree)
447+
441448

442449
class Compiled(Stage):
443450
"""Compiled representation of a function specialized to types/values.

tests/pjit_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4191,6 +4191,25 @@ def f(x):
41914191
self.assertLen(specialized.in_avals[0], 1)
41924192
self.assertLen(specialized.in_avals[1], 0) # empty kwarg
41934193

4194+
def test_jit_specialize_lower_and_compile(self):
4195+
def f(x):
4196+
return x * 2
4197+
4198+
lowered = jax.jit(f).specialize(jnp.arange(8)).lower()
4199+
self.assertEqual(lowered.args_info[0][0].shape, (8,))
4200+
4201+
compiled = lowered.compile()
4202+
out = compiled(jnp.arange(8))
4203+
self.assertArraysEqual(out, np.arange(8) * 2)
4204+
4205+
# fast-forward
4206+
lowered2 = jax.jit(f).lower(jnp.arange(8))
4207+
self.assertEqual(lowered2.args_info[0][0].shape, (8,))
4208+
4209+
compiled2 = lowered2.compile()
4210+
out2 = compiled2(jnp.arange(8))
4211+
self.assertArraysEqual(out2, np.arange(8) * 2)
4212+
41944213

41954214
def spec_regex(s):
41964215
return str(s).replace(r"(", r"\(").replace(r")", r"\)")

0 commit comments

Comments
 (0)