Skip to content

Commit fcf0b6d

Browse files
pschuhGoogle-ML-Automation
authored andcommitted
Add _raw_platform to work around extra platform normalization logic and enable
GPU aot compilation without a GPU present. Fixes jax-ml#23971 PiperOrigin-RevId: 702506848
1 parent ceeed90 commit fcf0b6d

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

jax/_src/interpreters/pxla.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2220,7 +2220,10 @@ def lower_sharding_computation(
22202220
out_shardings = _concretize_abstract_shardings(
22212221
out_shardings, global_out_avals, device_assignment)
22222222

2223-
platforms = lowering_platforms or (backend.platform,)
2223+
# TODO(parkers): One _raw_platform has been unified with platform,
2224+
# change this back to just read platform.
2225+
platforms = lowering_platforms or (
2226+
getattr(backend, "_raw_platform", backend.platform),)
22242227

22252228
committed = bool(
22262229
devices_from_context or

0 commit comments

Comments
 (0)