Skip to content

Commit 2d44f98

Browse files
dfmGoogle-ML-Automation
authored andcommitted
Finalize deprecation of ffi_call with inline arguments.
PiperOrigin-RevId: 745261995
1 parent 09fed2f commit 2d44f98

File tree

3 files changed

+10
-34
lines changed

3 files changed

+10
-34
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
4242
available from `jax.extend.mlir`.
4343
* `jax.interpreters.mlir.custom_call` is deprecated. The APIs provided by
4444
{mod}`jax.ffi` should be used instead.
45+
* The deprecated use of {func}`jax.ffi.ffi_call` with inline arguments is no
46+
longer supported. {func}`~jax.ffi.ffi_call` now unconditionally returns a
47+
callable.
4548
* Several previously-deprecated APIs have been removed, including:
4649
* From `jax.lib.xla_client`: `FftType`, `PaddingType`, `dtype_to_etype`,
4750
and `shape_from_pyval`.

jax/_src/ffi.py

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,6 @@
3939
from jax._src.typing import (Array, ArrayLike, DeprecatedArg, DuckTypedArray,
4040
Shape)
4141

42-
# TODO(dfm): Remove after 6 months or less because there aren't any offical
43-
# compatibility guarantees for jax.extend (see JEP 15856)
44-
# Added Oct 13, 2024
45-
deprecations.register("jax-ffi-call-args")
46-
4742
map, unsafe_map = util.safe_map, map
4843
FfiLayoutOptions = Sequence[int] | DeviceLocalLayout | None
4944

@@ -325,7 +320,7 @@ def _convert_layouts_for_ffi_call(
325320
def ffi_call(
326321
target_name: str,
327322
result_shape_dtypes: ResultMetadata,
328-
*deprecated_args: ArrayLike,
323+
*,
329324
has_side_effect: bool = ...,
330325
vmap_method: str | None = ...,
331326
input_layouts: Sequence[FfiLayoutOptions] | None = ...,
@@ -334,16 +329,15 @@ def ffi_call(
334329
custom_call_api_version: int = ...,
335330
legacy_backend_config: str | None = ...,
336331
vectorized: bool | DeprecatedArg = ...,
337-
**deprecated_kwargs: Any,
338-
) -> Callable[..., Array] | Array:
332+
) -> Callable[..., Array]:
339333
...
340334

341335

342336
@overload
343337
def ffi_call(
344338
target_name: str,
345339
result_shape_dtypes: Sequence[ResultMetadata],
346-
*deprecated_args: ArrayLike,
340+
*,
347341
has_side_effect: bool = ...,
348342
vmap_method: str | None = ...,
349343
input_layouts: Sequence[FfiLayoutOptions] | None = ...,
@@ -352,15 +346,14 @@ def ffi_call(
352346
custom_call_api_version: int = ...,
353347
legacy_backend_config: str | None = ...,
354348
vectorized: bool | DeprecatedArg = ...,
355-
**deprecated_kwargs: Any,
356-
) -> Callable[..., Sequence[Array]] | Sequence[Array]:
349+
) -> Callable[..., Sequence[Array]]:
357350
...
358351

359352

360353
def ffi_call(
361354
target_name: str,
362355
result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata],
363-
*deprecated_args: ArrayLike,
356+
*,
364357
has_side_effect: bool = False,
365358
vmap_method: str | None = None,
366359
input_layouts: Sequence[FfiLayoutOptions] | None = None,
@@ -369,8 +362,7 @@ def ffi_call(
369362
custom_call_api_version: int = 4,
370363
legacy_backend_config: str | None = None,
371364
vectorized: bool | DeprecatedArg = DeprecatedArg(),
372-
**deprecated_kwargs: Any,
373-
) -> Callable[..., Array | Sequence[Array]] | Array | Sequence[Array]:
365+
) -> Callable[..., Array | Sequence[Array]]:
374366
"""Call a foreign function interface (FFI) target.
375367
376368
See the :ref:`ffi-tutorial` tutorial for more information.
@@ -537,19 +529,7 @@ def wrapped(*args: ArrayLike, **kwargs: Any):
537529
else:
538530
return results[0]
539531

540-
if deprecated_args or deprecated_kwargs:
541-
deprecations.warn(
542-
"jax-ffi-call-args",
543-
"Calling ffi_call directly with input arguments is deprecated. "
544-
"Instead, ffi_call should be used to construct a callable, which can "
545-
"then be called with the appropriate inputs. For example,\n"
546-
" ffi_call('target_name', output_type, x, argument=5)\n"
547-
"should be replaced with\n"
548-
" ffi_call('target_name', output_type)(x, argument=5)",
549-
stacklevel=2)
550-
return wrapped(*deprecated_args, **deprecated_kwargs)
551-
else:
552-
return wrapped
532+
return wrapped
553533

554534

555535
# ffi_call must support some small non-hashable input arguments, like np.arrays

tests/ffi_test.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -208,13 +208,6 @@ def test_vectorized_deprecation(self):
208208
with self.assertWarns(DeprecationWarning):
209209
jax.vmap(ffi_call_geqrf)(x)
210210

211-
def test_backward_compat_syntax(self):
212-
def fun(x):
213-
return jax.ffi.ffi_call("test_ffi", x, x, param=0.5)
214-
msg = "Calling ffi_call directly with input arguments is deprecated"
215-
with self.assertDeprecationWarnsOrRaises("jax-ffi-call-args", msg):
216-
jax.jit(fun).lower(jnp.ones(5))
217-
218211
def test_input_output_aliases(self):
219212
def fun(x):
220213
return jax.ffi.ffi_call("test", x, input_output_aliases={0: 0})(x)

0 commit comments

Comments
 (0)