Skip to content

Commit 2aaa108

Browse files
committed
Raise an error when registering a lowering for an unknown platform
1 parent 801fe87 commit 2aaa108

File tree

3 files changed

+29
-2
lines changed

3 files changed

+29
-2
lines changed

jax/_src/interpreters/mlir.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,12 @@ def register_lowering(prim: core.Primitive, rule: LoweringRule,
834834
if platform is None:
835835
_lowerings[prim] = rule
836836
else:
837+
if not xb.is_known_platform(platform):
838+
known_platforms = sorted(xb.known_platforms())
839+
raise NotImplementedError(
840+
f"Registering an MLIR lowering rule for primitive {prim}"
841+
f" for an unknown platform {platform}. Known platforms are:"
842+
f" {', '.join(known_platforms)}.")
837843
# For backward compatibility reasons, we allow rules to be registered
838844
# under "gpu" even though the platforms are now called "cuda" and "rocm".
839845
# TODO(phawkins): fix up users to specify either "cuda" or "rocm" and remove

jax/_src/xla_bridge.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,9 @@ class BackendRegistration:
202202
# for unimplemented features. Wrong outputs are not acceptable.
203203
_nonexperimental_plugins: set[str] = {'cuda', 'rocm'}
204204

205+
# The set of known experimental plugins that have registrations in JAX codebase.
206+
_experimental_plugins: set[str] = {"METAL"}
207+
205208
def register_backend_factory(name: str, factory: BackendFactory, *,
206209
priority: int = 0,
207210
fail_quietly: bool = True,
@@ -774,12 +777,20 @@ def _discover_and_register_pjrt_plugins():
774777
_alias_to_platforms.setdefault(_alias, []).append(_platform)
775778

776779

780+
def known_platforms() -> set[str]:
781+
platforms = set()
782+
platforms |= set(_nonexperimental_plugins)
783+
platforms |= set(_experimental_plugins)
784+
platforms |= set(_backend_factories.keys())
785+
platforms |= set(_platform_aliases.values())
786+
return platforms
787+
788+
777789
def is_known_platform(platform: str) -> bool:
778790
# A platform is valid if there is a registered factory for it. It does not
779791
# matter if we were unable to initialize that platform; we only care that
780792
# we've heard of it and it isn't, e.g., a typo.
781-
return (platform in _backend_factories.keys() or
782-
platform in _platform_aliases.keys())
793+
return platform in known_platforms()
783794

784795

785796
def canonicalize_platform(platform: str) -> str:

tests/extend_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,5 +291,15 @@ def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, **kwargs):
291291
)
292292

293293

294+
class MlirRegisterLoweringTest(jtu.JaxTestCase):
295+
296+
def test_unknown_platform_error(self):
297+
with self.assertRaisesRegex(
298+
NotImplementedError,
299+
"Registering an MLIR lowering rule for primitive .+ for an unknown "
300+
"platform foo. Known platforms are: .+."):
301+
mlir.register_lowering(prim=None, rule=None, platform="foo")
302+
303+
294304
if __name__ == "__main__":
295305
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)