
Commit 81f8e65

Merge pull request #229 from ROCm/ci-upstream-sync-112_1
CI: 02/11/25 upstream sync
2 parents 073bc92 + 3745591 commit 81f8e65


83 files changed: +1789 / -1576 lines

.bazelrc

Lines changed: 1 addition & 1 deletion
@@ -233,7 +233,7 @@ build:ci_linux_aarch64_cuda --config=cuda --config=build_cuda_with_nvcc
 build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"

 # Mac x86 CI configs
-build:ci_darwin_x86_64 --macos_minimum_os=10.14
+build:ci_darwin_x86_64 --macos_minimum_os=11.0
 build:ci_darwin_x86_64 --config=macos_cache_push
 build:ci_darwin_x86_64 --verbose_failures=true
 build:ci_darwin_x86_64 --color=yes

.github/workflows/cloud-tpu-ci-nightly.yml

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ jobs:
        python-version: ["3.10"]
    name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
    env:
-     LIBTPU_OLDEST_VERSION_DATE: 20240922
+     LIBTPU_OLDEST_VERSION_DATE: 20241118
      PYTHON: python${{ matrix.python-version }}
    runs-on: ${{ matrix.tpu.runner }}
    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"

build/build.py

Lines changed: 2 additions & 0 deletions
@@ -588,6 +588,8 @@ async def main():
   )
   for option in args.bazel_options:
     wheel_build_command_base.append(option)
+  if "cuda" in args.wheels:
+    wheel_build_command_base.append("--config=cuda_libraries_from_stubs")

   with open(".jax_configure.bazelrc", "w") as f:
     jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list())

docs/pallas/grid_blockspec.md

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ If the block shape does not divide evenly the overall shape then the
 last iteration on each axis will still receive references to blocks
 of `block_shape` but the elements that are out-of-bounds are padded
 on input and discarded on output. The values of the padding are unspecified, and
-you should assume they is garbage. In the `interpret=True` mode, we
+you should assume they are garbage. In the `interpret=True` mode, we
 pad with NaN for floating-point values, to give users a chance to
 spot accessing out-of-bounds elements, but this behavior should not
 be depended upon. Note that at least one of the
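
To make the padding behavior described in this hunk concrete, here is a minimal sketch (not part of this commit; the kernel name, shapes, and grid are made up for illustration): a 10-row input processed in 8-row blocks, so the last grid step sees a partially out-of-bounds block.

# Illustrative sketch only; assumes the current pl.BlockSpec/pl.pallas_call API.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def double_kernel(x_ref, o_ref):
  # The last grid step receives an 8-row block even though only 2 rows remain.
  o_ref[...] = x_ref[...] * 2.0

x = jnp.arange(40, dtype=jnp.float32).reshape(10, 4)
out = pl.pallas_call(
    double_kernel,
    grid=(2,),  # two 8-row blocks cover the 10 rows; the second is padded
    in_specs=[pl.BlockSpec(block_shape=(8, 4), index_map=lambda i: (i, 0))],
    out_specs=pl.BlockSpec(block_shape=(8, 4), index_map=lambda i: (i, 0)),
    out_shape=jax.ShapeDtypeStruct((10, 4), jnp.float32),
    interpret=True,  # in interpret mode the out-of-bounds padding shows up as NaN
)(x)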

jax/_src/api.py

Lines changed: 1 addition & 2 deletions
@@ -2060,10 +2060,9 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
   shape/dtypes/structure as ``primals``.

   >>> import jax
-  >>> import types
   >>>
   >>> f = lambda x, y: 0.5 * x - 0.5 * y
-  >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
+  >>> scalar = jax.ShapeDtypeStruct(shape=(), dtype=np.dtype(np.float32))
   >>> f_transpose = jax.linear_transpose(f, scalar, scalar)
   >>> f_transpose(1.0)
   (Array(0.5, dtype=float32), Array(-0.5, dtype=float32))

jax/_src/api_util.py

Lines changed: 23 additions & 6 deletions
@@ -27,7 +27,7 @@
 from jax._src.state.types import AbstractRef
 from jax._src.tree_util import (
     PyTreeDef, tree_flatten, tree_unflatten, tree_map,
-    treedef_children, generate_key_paths, keystr, broadcast_prefix,
+    treedef_children, generate_key_paths, broadcast_prefix,
     prefix_errors)
 from jax._src.tree_util import _replace_nones
 from jax._src import linear_util as lu
@@ -595,6 +595,15 @@ def debug_info(
     sourceinfo: str | None = None,
     signature: inspect.Signature | None = None,
 ) -> core.DebugInfo:
+  """Constructs core.DebugInfo for a function given example args and kwargs.
+
+  `args` and `kwargs` are example positional and keyword arguments, used with
+  `inspect.Signature` to get the names of arguments. The arguments that are
+  considered static for tracing purposes should be included, and designated
+  using `static_argnums` and `static_argnames`.
+
+  See the docstring for linear_util.DebugInfo.
+  """
   if sourceinfo is None:
     sourceinfo = fun_sourceinfo(fun)
   if signature is None:
@@ -610,10 +619,17 @@ def fun_signature(fun: Callable) -> inspect.Signature | None:
   except (ValueError, TypeError):
     return None

-def save_wrapped_fun_sourceinfo(wrapper: Callable, wrapped: Callable):
+def save_wrapped_fun_sourceinfo(wrapper: Callable,
+                                wrapped: Callable | core.DebugInfo | None) -> None:
   # Prefer this to functools.wraps because it does not create a reference to
   # the wrapped function.
-  setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped))
+  if isinstance(wrapped, core.DebugInfo):
+    func_src_info = wrapped.func_src_info
+  elif callable(wrapped):
+    func_src_info = fun_sourceinfo(wrapped)
+  else:
+    return
+  setattr(wrapper, "__fun_sourceinfo__", func_src_info)

 _fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")

@@ -664,12 +680,13 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None,
   except (ValueError, TypeError):
     pass
   else:
-    return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
+    return tuple(f'{name}{lu._clean_keystr_arg_names(path)}'
+                 for name, x in ba.arguments.items()
                  for path, l in generate_key_paths(x) if l is not static)
-  args_arg_names = tuple(f'args{keystr(path)}'
+  args_arg_names = tuple(f'args{lu._clean_keystr_arg_names(path)}'
                          for path, l in generate_key_paths(args_)
                          if l is not static)
-  kwargs_arg_names = tuple(f'kwargs{keystr(path)}'
+  kwargs_arg_names = tuple(f'kwargs{lu._clean_keystr_arg_names(path)}'
                            for path, l in generate_key_paths(kwargs_)
                            if l is not static)
   arg_names = args_arg_names + kwargs_arg_names
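
For context, a hedged sketch of the convention the new `debug_info` docstring describes, mirroring the call sites elsewhere in this commit (e.g. custom_batching.py); `my_transform` is a hypothetical wrapper, not part of JAX.

# Sketch only: builds a core.DebugInfo from example arguments and attaches it
# to the wrapped function so downstream tracing keeps useful argument names.
from jax._src import api_util
from jax._src import linear_util as lu

def my_transform(fun, *example_args, **example_kwargs):
  dbg = api_util.debug_info("my_transform", fun, example_args, example_kwargs)
  return lu.wrap_init(fun, debug_info=dbg)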

jax/_src/callback.py

Lines changed: 15 additions & 7 deletions
@@ -35,6 +35,7 @@
 from jax._src.interpreters import mlir
 from jax._src.lax import lax
 from jax._src.lax.control_flow.loops import map as lax_map
+from jax._src.lax.control_flow.loops import scan
 from jax._src.lib import xla_client as xc
 from jax._src.sharding_impls import SingleDeviceSharding
 from jax._src.typing import DeprecatedArg
@@ -163,7 +164,10 @@ def callback_batching_rule(

   # For FFI calls we must update the layouts. We handle the output layouts
   # here, but the input layout updates depend on the vmap_method parameter.
-  if vmap_method != "sequential" and kwargs.get("output_layouts") is not None:
+  if (
+      vmap_method not in ("sequential", "sequential_unrolled") and
+      kwargs.get("output_layouts") is not None
+  ):
     kwargs["output_layouts"] = tuple(
         None if layout is None else tuple(n + 1 for n in layout) + (0,)
         for layout in kwargs["output_layouts"])
@@ -199,7 +203,7 @@ def callback_batching_rule(
         result_avals=batched_result_avals,
         **kwargs,
     )
-  elif vmap_method == "sequential":
+  elif vmap_method == "sequential" or vmap_method == "sequential_unrolled":
     is_batched = [d is not batching.not_mapped for d in dims]
     unbatched_args, batched_args = util.partition_list(is_batched, new_args)
     def _batch_fun(batched_args):
@@ -211,12 +215,14 @@ def _batch_fun(batched_args):
           vmap_method=vmap_method,
           **kwargs,
       )
-    outvals = lax_map(_batch_fun, batched_args)
+    unroll = vmap_method == "sequential_unrolled"
+    g = lambda _, x: ((), _batch_fun(x))
+    _, outvals = scan(g, (), batched_args, unroll=unroll)
   else:
     raise NotImplementedError(
         f"vmap is only supported for the {prim.name} primitive when vmap_method "
-        "is one of 'sequential', 'expand_dims', 'broadcast_all', or "
-        "'legacy_vectorized'.")
+        "is one of 'sequential', 'sequential_unrolled', 'expand_dims', "
+        f"'broadcast_all', or 'legacy_vectorized'. Got {vmap_method=}.")
   return tuple(outvals), (0,) * len(outvals)


@@ -371,6 +377,8 @@ def pure_callback(
     is deprecated and it will eventually raise ``NotImplementedError``.
   * ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over
     the batched arguments, calling ``callback`` once for each batch element.
+  * ``vmap_method="sequential_unrolled"`` is like ``sequential``, but the loop
+    is unrolled.
   * ``vmap_method="expand_dims"`` calls ``callback`` with new axes of size ``1``
     added as the leading dimension unbatched inputs.
   * ``vmap_method="broadcast_all"`` behaves like ``expand_dims``, but the
@@ -459,8 +467,8 @@ def pure_callback(
         "the vectorized and vmap_method arguments of jax.pure_callback cannot "
         "be used together. Please use the vmap_method argument.")
     vmap_method = "legacy_vectorized" if vectorized else "sequential"
-  allowed_vmap_methods = ["sequential", "expand_dims", "broadcast_all",
-                          "legacy_vectorized", None]
+  allowed_vmap_methods = ["sequential", "sequential_unrolled", "expand_dims",
+                          "broadcast_all", "legacy_vectorized", None]
   if vmap_method not in allowed_vmap_methods:
     raise ValueError(
         f"vmap_method must be on of the allowed methods {allowed_vmap_methods}, "

jax/_src/checkify.py

Lines changed: 20 additions & 14 deletions
@@ -833,7 +833,7 @@ def new_body_f(*c_consts_and_vals):
     # This checks if the next cond application will error
     _ = cond_f(*c_consts, *out)
     return out
-  new_body_f_ = lu.wrap_init(new_body_f)
+  new_body_f_ = lu.wrap_init(new_body_f, debug_info=body_jaxpr.jaxpr.debug_info)
   c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
   jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
                                                              *body_jaxpr.in_avals])
@@ -952,7 +952,8 @@ def remat_error_check(error, enabled_errors, *vals_in, jaxpr, **params):


 def shard_map_error_check(
-    error, enabled_errors, *vals_in, jaxpr, in_names, out_names, **kwargs
+    error: Error, enabled_errors, *vals_in,
+    jaxpr: core.Jaxpr, in_names, out_names, **kwargs
 ):
   if (mesh := kwargs.get('mesh')) is None:
     raise ValueError('Mesh must be provided for shard_map with checkify.')
@@ -976,7 +977,6 @@
   )
   num_out_error_vals = out_tree.num_leaves - len(out_names)

-  @lu.wrap_init
   def expand_errors_leading_dim(*xs):
     outs = core.eval_jaxpr(checked_jaxpr.jaxpr, checked_jaxpr.consts, *xs)
     errs, outs = split_list(outs, [num_out_error_vals])
@@ -985,15 +985,18 @@

   with core.extend_axis_env_nd(mesh.shape.items()):
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
-        expand_errors_leading_dim, checked_jaxpr.in_avals
+        lu.wrap_init(expand_errors_leading_dim,
+                     debug_info=checked_jaxpr.jaxpr.debug_info),
+        checked_jaxpr.in_avals
     )
   checked_jaxpr = core.ClosedJaxpr(jaxpr, consts)

   # Update shard_map params to account for extra error values.
   # Use fully sharded partitioning for out errors.
   new_out_names = (*([{0: mesh.axis_names}] * num_out_error_vals), *out_names)
   subfun = lu.hashable_partial(
-      lu.wrap_init(core.eval_jaxpr), checked_jaxpr.jaxpr, checked_jaxpr.consts
+      lu.wrap_init(core.eval_jaxpr, debug_info=checked_jaxpr.jaxpr.debug_info),
+      checked_jaxpr.jaxpr, checked_jaxpr.consts
   )
   new_params = dict(
       jaxpr=checked_jaxpr.jaxpr,
@@ -1007,8 +1010,10 @@ def expand_errors_leading_dim(*xs):
   return tree_unflatten(out_tree, err_and_out)
 error_checks[shard_map.shard_map_p] = shard_map_error_check

-def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
-                         jvp_jaxpr_thunk, call_jaxpr, **params):
+def custom_jvp_call_rule(in_err: Error,
+                         enabled_errors: set, *in_vals, num_consts,
+                         jvp_jaxpr_fun: lu.WrappedFun,
+                         call_jaxpr: core.ClosedJaxpr, **params):
   # The types to have in mind are:
   # jvp : (a -> b) -> (a, T a) -> (b, T b)
   # checkify : (a -> b) -> a -> Err b
@@ -1021,10 +1026,11 @@ def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
   err_vals, err_tree = jtu.tree_flatten(in_err)
   partial_checkify = lu.wrap_init(
       functools.partial(checkify_jaxpr_flat, call_jaxpr.jaxpr,
-                        call_jaxpr.consts, enabled_errors, err_tree))
+                        call_jaxpr.consts, enabled_errors, err_tree),
+      debug_info=call_jaxpr.jaxpr.debug_info)
   partial_checkify, f_metadata = _flatten_and_get_error_metadata_thunk(
       partial_checkify)
-  jvp = lift_jvp(err_tree.num_leaves, num_consts, jvp_jaxpr_thunk)
+  jvp = lift_jvp(err_tree.num_leaves, num_consts, jvp_jaxpr_fun)
   jvp, jvp_out_tree = flatten_fun_output(jvp)
   all_outs = custom_derivatives.custom_jvp_call_p.bind(
       partial_checkify, jvp, *err_vals, *in_vals, **params)
@@ -1041,17 +1047,17 @@ def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,

 # Compared to custom_derivatives.lift_jvp, we're handling the extra inputs and
 # outputs that checkify adds (just forwarding the error data's primal and
-# tangent components). The jaxpr in jvp_jaxpr_thunk doesn't expect those.
+# tangent components). The jaxpr in jvp_jaxpr_fun doesn't expect those.
 # TODO(mattjj): can we simplify this, or dedup with custom_derivatives.lift_jvp?
 # Adding another layer of lu.transformation was tricky, though maybe doable.
-def lift_jvp(num_errs, num_consts, jvp_jaxpr_thunk):
-  @lu.wrap_init
+def lift_jvp(num_errs: int, num_consts: int,
+             jvp_jaxpr_fun: lu.WrappedFun) -> lu.WrappedFun:
   def jvp(*xs):
     n, ragged = divmod(len(xs), 2)
     assert not ragged
     primals, tangents = xs[num_consts+num_errs:n], xs[n+num_consts+num_errs:]
     zeros = [type(t) is SymbolicZero for t in tangents]
-    jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_thunk(*zeros)
+    jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_fun.call_wrapped(*zeros)
     nonzero_tangents = [t for t in tangents if type(t) is not SymbolicZero]
     out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
     out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
@@ -1063,7 +1069,7 @@ def jvp(*xs):
     primal_errs = xs[num_consts:num_consts+num_errs]
     tangent_errs = xs[n+num_consts:n+num_consts+num_errs]
     return [*primal_errs, *out_primals, *tangent_errs, *out_tangents]
-  return jvp
+  return lu.wrap_init(jvp, debug_info=jvp_jaxpr_fun.debug_info)

 def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals,
                                fun_jaxpr: core.ClosedJaxpr,
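
For context, a user-level sketch (not part of this commit) that runs checkify over a function containing a custom_jvp call, one of the code paths whose debug_info plumbing changes above; `safe_log`, `f`, and the check are illustrative only.

import jax
import jax.numpy as jnp
from jax.experimental import checkify

@jax.custom_jvp
def safe_log(x):
  return jnp.log(x)

@safe_log.defjvp
def safe_log_jvp(primals, tangents):
  (x,), (t,) = primals, tangents
  return safe_log(x), t / x

def f(x):
  checkify.check(x > 0, "x must be positive")
  return safe_log(x)

err, out = checkify.checkify(f)(jnp.float32(2.0))
err.throw()  # no error is raised for a positive input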

jax/_src/core.py

Lines changed: 2 additions & 1 deletion
@@ -640,7 +640,8 @@ def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *,
           "to handle custom_jvp primitives")
     raise NotImplementedError(msg)

-  def process_custom_transpose(self, prim, call, tracers, **params):
+  def process_custom_transpose(self, prim: Primitive,
+                               call: lu.WrappedFun, tracers, **params):
     msg = (f"{type(self)} must override process_custom_transpose "
            "to handle custom_transpose_call primitives")
     raise NotImplementedError(msg)

jax/_src/custom_batching.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -141,25 +141,28 @@ def def_vmap(
141141

142142
@traceback_util.api_boundary
143143
def __call__(self, *args, **kwargs):
144+
debug_fun = api_util.debug_info("custom_vmap fun", self.fun,
145+
args, kwargs)
144146
args = api_util.resolve_kwargs(self.fun, args, kwargs)
145-
fun_name = getattr(self.fun, "__name__", str(self.fun))
146147
if not self.vmap_rule:
147148
raise AttributeError(
148-
f"No batching rule defined for custom_vmap function {fun_name} "
149+
f"No batching rule defined for custom_vmap function {debug_fun.func_name} "
149150
"using def_vmap.")
150-
debug = api_util.debug_info("custom_vmap", self.fun, args, {})
151151
args_flat, in_tree = tree_flatten(args)
152152
flat_fun, out_tree = api_util.flatten_fun_nokwargs(
153-
lu.wrap_init(self.fun, debug_info=debug),
153+
lu.wrap_init(self.fun, debug_info=debug_fun),
154154
in_tree)
155155
in_avals = [core.get_aval(x) for x in args_flat]
156156
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
157157
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
158158
in_tree = treedef_tuple((tree_structure(consts), in_tree))
159159
assert self.vmap_rule is not None
160+
debug_rule = api_util.debug_info("custom_vmap rule", self.vmap_rule,
161+
(0, args, args), {})
160162
out_flat = custom_vmap_p.bind(*consts, *args_flat,
161163
call=closed_call,
162-
rule=ClosedRule(self.vmap_rule),
164+
rule=ClosedRule(self.vmap_rule,
165+
debug_rule),
163166
in_tree=in_tree,
164167
out_tree=out_tree())
165168
return tree_unflatten(out_tree(), out_flat)
@@ -170,9 +173,10 @@ def __call__(self, *args, **kwargs):
170173
# Define a class, instead of making a function closing over `rule`, so
171174
# that we can override __str__
172175
class ClosedRule:
173-
def __init__(self, rule):
176+
def __init__(self, rule: Callable, debug: core.DebugInfo):
174177
functools.update_wrapper(self, rule)
175178
self.rule = rule
179+
self.debug = debug
176180

177181
def __call__(self, axis_size, all_in_batched, *all_args):
178182
_, args = all_args
@@ -252,8 +256,11 @@ def custom_vmap_abstract_eval(*in_avals, call, **_):
252256
return call.out_avals
253257

254258

255-
def custom_vmap_jvp(primals, tangents, *, call, rule, in_tree, out_tree):
256-
def jvp_of_rule_rule(axis_size, in_batched, primals, tangents):
259+
def custom_vmap_jvp(primals, tangents, *,
260+
call: core.ClosedJaxpr,
261+
rule: ClosedRule,
262+
in_tree: tree_util.PyTreeDef, out_tree: tree_util.PyTreeDef):
263+
def jvp_of_rule_rule(axis_size: int, in_batched, primals, tangents):
257264
in_batched_ps, in_batched_ts = in_batched
258265

259266
mutually_batched = tree_map(operator.and_, in_batched_ps, in_batched_ts)
@@ -281,11 +288,14 @@ def to_jvp(*primals):
281288
out_mutually_batched.store(out_batched)
282289
return out
283290

291+
api_util.save_wrapped_fun_sourceinfo(to_jvp, call.jaxpr.debug_info)
284292
def to_vmap_over_extra_batched_dims(primals, tangents):
285293
return api.jvp(to_jvp, primals, tangents)
286294

287295
to_vmap_over_extra_batched_dims_flat, out_tree2 = api_util.flatten_fun_nokwargs(
288-
lu.wrap_init(to_vmap_over_extra_batched_dims),
296+
lu.wrap_init(to_vmap_over_extra_batched_dims,
297+
# TODO(necula): fix the debug_info calling convention
298+
debug_info=call.jaxpr.debug_info),
289299
tree_ps_ts)
290300

291301
flat_out_ps_ts, flat_out_axes = vmap_unrestricted(
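
For context, a hedged user-level sketch (not part of this commit) of the jax.custom_batching.custom_vmap API whose tracing paths the debug_info changes above touch; `vector_dot` and its rule are made up for illustration, and the rule assumes both arguments are batched along their leading axis.

import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap

@custom_vmap
def vector_dot(u, v):
  return jnp.dot(u, v)

@vector_dot.def_vmap
def vector_dot_vmap_rule(axis_size, in_batched, u, v):
  del axis_size
  u_batched, v_batched = in_batched
  assert u_batched and v_batched  # sketch handles only the fully batched case
  out = jnp.einsum("ni,ni->n", u, v)
  return out, True

us = jnp.ones((3, 4))
vs = jnp.ones((3, 4))
print(jax.vmap(vector_dot)(us, vs))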
