3939from jax ._src .typing import (Array , ArrayLike , DeprecatedArg , DuckTypedArray ,
4040 Shape )
4141
42- # TODO(dfm): Remove after 6 months or less because there aren't any offical
43- # compatibility guarantees for jax.extend (see JEP 15856)
44- # Added Oct 13, 2024
45- deprecations .register ("jax-ffi-call-args" )
46-
4742map , unsafe_map = util .safe_map , map
4843FfiLayoutOptions = Sequence [int ] | DeviceLocalLayout | None
4944
@@ -325,7 +320,7 @@ def _convert_layouts_for_ffi_call(
325320def ffi_call (
326321 target_name : str ,
327322 result_shape_dtypes : ResultMetadata ,
328- * deprecated_args : ArrayLike ,
323+ * ,
329324 has_side_effect : bool = ...,
330325 vmap_method : str | None = ...,
331326 input_layouts : Sequence [FfiLayoutOptions ] | None = ...,
@@ -334,16 +329,15 @@ def ffi_call(
334329 custom_call_api_version : int = ...,
335330 legacy_backend_config : str | None = ...,
336331 vectorized : bool | DeprecatedArg = ...,
337- ** deprecated_kwargs : Any ,
338- ) -> Callable [..., Array ] | Array :
332+ ) -> Callable [..., Array ]:
339333 ...
340334
341335
342336@overload
343337def ffi_call (
344338 target_name : str ,
345339 result_shape_dtypes : Sequence [ResultMetadata ],
346- * deprecated_args : ArrayLike ,
340+ * ,
347341 has_side_effect : bool = ...,
348342 vmap_method : str | None = ...,
349343 input_layouts : Sequence [FfiLayoutOptions ] | None = ...,
@@ -352,15 +346,14 @@ def ffi_call(
352346 custom_call_api_version : int = ...,
353347 legacy_backend_config : str | None = ...,
354348 vectorized : bool | DeprecatedArg = ...,
355- ** deprecated_kwargs : Any ,
356- ) -> Callable [..., Sequence [Array ]] | Sequence [Array ]:
349+ ) -> Callable [..., Sequence [Array ]]:
357350 ...
358351
359352
360353def ffi_call (
361354 target_name : str ,
362355 result_shape_dtypes : ResultMetadata | Sequence [ResultMetadata ],
363- * deprecated_args : ArrayLike ,
356+ * ,
364357 has_side_effect : bool = False ,
365358 vmap_method : str | None = None ,
366359 input_layouts : Sequence [FfiLayoutOptions ] | None = None ,
@@ -369,8 +362,7 @@ def ffi_call(
369362 custom_call_api_version : int = 4 ,
370363 legacy_backend_config : str | None = None ,
371364 vectorized : bool | DeprecatedArg = DeprecatedArg (),
372- ** deprecated_kwargs : Any ,
373- ) -> Callable [..., Array | Sequence [Array ]] | Array | Sequence [Array ]:
365+ ) -> Callable [..., Array | Sequence [Array ]]:
374366 """Call a foreign function interface (FFI) target.
375367
376368 See the :ref:`ffi-tutorial` tutorial for more information.
@@ -537,19 +529,7 @@ def wrapped(*args: ArrayLike, **kwargs: Any):
537529 else :
538530 return results [0 ]
539531
540- if deprecated_args or deprecated_kwargs :
541- deprecations .warn (
542- "jax-ffi-call-args" ,
543- "Calling ffi_call directly with input arguments is deprecated. "
544- "Instead, ffi_call should be used to construct a callable, which can "
545- "then be called with the appropriate inputs. For example,\n "
546- " ffi_call('target_name', output_type, x, argument=5)\n "
547- "should be replaced with\n "
548- " ffi_call('target_name', output_type)(x, argument=5)" ,
549- stacklevel = 2 )
550- return wrapped (* deprecated_args , ** deprecated_kwargs )
551- else :
552- return wrapped
532+ return wrapped
553533
554534
555535# ffi_call must support some small non-hashable input arguments, like np.arrays
0 commit comments