Skip to content

Conversation

@rocm-repo-management-api-2
Copy link

Daily sync with upstream

sharadmv and others added 30 commits June 6, 2025 04:54
…ace-time overhead from Exception creation

PiperOrigin-RevId: 768022818
PiperOrigin-RevId: 768037136
…ref`.

The new function `_transformed_smem_ref_type` will be used in a follow up change.

PiperOrigin-RevId: 768045658
We port:
* the class itself
* `__init__`
* `__hash__`
* `__eq__`

which is enough to get a small speedup. We also do not change the representation of the data as two Python tuples for the moment.

PiperOrigin-RevId: 768069350
* Don't spend time on annotation formatting if there are no annotations.
* Remove an assertion that the children of a ConcatDoc are Docs.
* Fuse _align_annotations with the code that produces strings, which saves allocating a NamedTuple per line.
* Don't call .format() to test whether the LHS of a jaxpr equation prints as empty.

PiperOrigin-RevId: 768144149
`jax.extend.backend` allows the user to register a callback that will be called
when JAX backends are cleared. The primary purpose is to let the user clear any
caches that hold a reference to JAX backends transitively (via JAX
`Sharding`/`Mesh`/`Device`) so that it can help destroy cleared backends.

PiperOrigin-RevId: 768173340
…one when we are discharging the ref.

The reason PRNGKeyArray doesn't have a `format` field is we don't know how to create a logical format.dll for it.

PiperOrigin-RevId: 768186349
PiperOrigin-RevId: 768186737
PiperOrigin-RevId: 768187911
* Cache the dtype to short name conversion.
* Use pp.concat when concatenating more than 2 things. This builds a slightly flatter tree of pretty-printer documents.

PiperOrigin-RevId: 768195993
…n a mutable array is closed over

PiperOrigin-RevId: 768220283
PiperOrigin-RevId: 768485986
This yields a significant speedup (3x) when printing large jaxprs.

PiperOrigin-RevId: 768803336
PiperOrigin-RevId: 768815050
Previously we disabled the jax2tf_test for older versions of TF.
Re-enable for 2.19.1 and higher.
…ve_scan_reverse_argument_order

PiperOrigin-RevId: 769139664
Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times.

This required a few local imports and refactors.

PiperOrigin-RevId: 769184594
Jake VanderPlas and others added 28 commits June 18, 2025 13:02
This also bundles-in `ad_checkpoint.py` and `state/*.py` because these have circular dependencies on lax source files.

Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, prevents use of internal APIs, and leads to improved build and iteration times.

PiperOrigin-RevId: 773030035
The issue was that CompilerParams is ambiguous when generating docs from
type annotations; we can fix this by specifying which CompilerParams is intended.
Remove `use_shardy_partitioner` in `get_compile_options`. It can be from the jax.config directly.

PiperOrigin-RevId: 773150657
PiperOrigin-RevId: 773232950
Also do a couple of clean ups.

PiperOrigin-RevId: 773238722
…t_cast`s.

This uncovers a propagation bug, whereby opportunities to propagate replicated
layouts that were not already explicitly annotated as attributes downwards
would be missed, because layout propagation started with a backwards pass. We
now changed the implementation to start with a forward pass.

Some additional edits:

1. I changed the layout in `test_optimization_barrier_op_propagates_user_layouts`.
   Generally, propagating replicated layouts upwards is not a safe thing to do,
   and we should have properly caught that. The upcoming infrastructure will
   recognize such issues, so we don't bother attempting to fix the underlying
   problem here;
2. I got rid of `test_infer_layout_propagates_func_layouts_to_ops` since we no
   longer care about `FuncOp`s.

This simplification will allow the new infrastructure to not concern itself
with `FuncOp`s, on which we were putting inconsistent expectations, and which
would add quite a bit of complexity.

PiperOrigin-RevId: 773277630
… `gpu_layout_inference_test.py`.

This makes checking for layouts more synthetic :)

PiperOrigin-RevId: 773288423
…d replace it with `Format`, `.format`, `.input_formats` and `.output_formats` in JAX

Co-authored-by: Roy Frostig <[email protected]>
PiperOrigin-RevId: 773337503
@rocm-repo-management-api-2 rocm-repo-management-api-2 bot requested a review from a team as a code owner June 20, 2025 06:02
@rocm-repo-management-api-2 rocm-repo-management-api-2 bot enabled auto-merge (rebase) June 20, 2025 06:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.