Skip to content

Commit 65307ab

Browse files
Merge pull request jax-ml#24370 from dfm:ffi-call-to-callable
PiperOrigin-RevId: 688188390
2 parents ad53add + 0b651f0 commit 65307ab

File tree

6 files changed

+113
-114
lines changed

6 files changed

+113
-114
lines changed

docs/ffi.ipynb

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -287,27 +287,27 @@
287287
" if x.dtype != jnp.float32:\n",
288288
" raise ValueError(\"Only the float32 dtype is implemented by rms_norm\")\n",
289289
"\n",
290-
" # In this case, the output of our FFI function is just a single array with the\n",
291-
" # same shape and dtype as the input. We discuss a case with a more interesting\n",
292-
" # output type below.\n",
293-
" out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)\n",
294-
"\n",
295-
" return jex.ffi.ffi_call(\n",
290+
" call = jex.ffi.ffi_call(\n",
296291
" # The target name must be the same string as we used to register the target\n",
297292
" # above in `register_custom_call_target`\n",
298293
" \"rms_norm\",\n",
299-
" out_type,\n",
300-
" x,\n",
301-
" # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for\n",
302-
" # the attribute `eps`. Our FFI function expects this to have the C++ `float`\n",
303-
" # type (which corresponds to numpy's `float32` type), and it must be a\n",
304-
" # static parameter (i.e. not a JAX array).\n",
305-
" eps=np.float32(eps),\n",
294+
"\n",
295+
" # In this case, the output of our FFI function is just a single array with\n",
296+
" # the same shape and dtype as the input. We discuss a case with a more\n",
297+
" # interesting output type below.\n",
298+
" jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
299+
"\n",
306300
" # The `vmap_method` parameter controls this function's behavior under `vmap`\n",
307301
" # as discussed below.\n",
308302
" vmap_method=\"broadcast_fullrank\",\n",
309303
" )\n",
310304
"\n",
305+
" # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for\n",
306+
" # the attribute `eps`. Our FFI function expects this to have the C++ `float`\n",
307+
" # type (which corresponds to numpy's `float32` type), and it must be a\n",
308+
" # static parameter (i.e. not a JAX array).\n",
309+
" return call(x, eps=np.float32(eps))\n",
310+
"\n",
311311
"\n",
312312
"# Test that this gives the same result as our reference implementation\n",
313313
"x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5))\n",
@@ -403,10 +403,8 @@
403403
" return jex.ffi.ffi_call(\n",
404404
" \"rms_norm\",\n",
405405
" jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
406-
" x,\n",
407-
" eps=np.float32(eps),\n",
408406
" vmap_method=\"sequential\",\n",
409-
" )\n",
407+
" )(x, eps=np.float32(eps))\n",
410408
"\n",
411409
"\n",
412410
"jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)"
@@ -462,10 +460,8 @@
462460
" jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
463461
" jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),\n",
464462
" ),\n",
465-
" x,\n",
466-
" eps=np.float32(eps),\n",
467463
" vmap_method=\"broadcast_fullrank\",\n",
468-
" )\n",
464+
" )(x, eps=np.float32(eps))\n",
469465
" return y, (res, x)\n",
470466
"\n",
471467
"\n",
@@ -478,11 +474,8 @@
478474
" jex.ffi.ffi_call(\n",
479475
" \"rms_norm_bwd\",\n",
480476
" jax.ShapeDtypeStruct(ct.shape, ct.dtype),\n",
481-
" res,\n",
482-
" x,\n",
483-
" ct,\n",
484-
" vmap_method=\"broadcast_fullrank\",\n",
485-
" ),\n",
477+
" vmap_method=\"broadcast_fullrank\",\n",
478+
" )(res, x, ct),\n",
486479
" )\n",
487480
"\n",
488481
"\n",
@@ -569,10 +562,8 @@
569562
" return lambda x: jex.ffi.ffi_call(\n",
570563
" target_name,\n",
571564
" out_type,\n",
572-
" x,\n",
573-
" eps=np.float32(eps),\n",
574565
" vmap_method=\"broadcast_fullrank\",\n",
575-
" )\n",
566+
" )(x, eps=np.float32(eps))\n",
576567
"\n",
577568
" return jax.lax.platform_dependent(x, cpu=impl(\"rms_norm\"), cuda=impl(\"rms_norm_cuda\"))\n",
578569
"\n",

docs/ffi.md

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -248,27 +248,27 @@ def rms_norm(x, eps=1e-5):
248248
if x.dtype != jnp.float32:
249249
raise ValueError("Only the float32 dtype is implemented by rms_norm")
250250
251-
# In this case, the output of our FFI function is just a single array with the
252-
# same shape and dtype as the input. We discuss a case with a more interesting
253-
# output type below.
254-
out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
255-
256-
return jex.ffi.ffi_call(
251+
call = jex.ffi.ffi_call(
257252
# The target name must be the same string as we used to register the target
258253
# above in `register_custom_call_target`
259254
"rms_norm",
260-
out_type,
261-
x,
262-
# Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for
263-
# the attribute `eps`. Our FFI function expects this to have the C++ `float`
264-
# type (which corresponds to numpy's `float32` type), and it must be a
265-
# static parameter (i.e. not a JAX array).
266-
eps=np.float32(eps),
255+
256+
# In this case, the output of our FFI function is just a single array with
257+
# the same shape and dtype as the input. We discuss a case with a more
258+
# interesting output type below.
259+
jax.ShapeDtypeStruct(x.shape, x.dtype),
260+
267261
# The `vmap_method` parameter controls this function's behavior under `vmap`
268262
# as discussed below.
269263
vmap_method="broadcast_fullrank",
270264
)
271265
266+
# Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for
267+
# the attribute `eps`. Our FFI function expects this to have the C++ `float`
268+
# type (which corresponds to numpy's `float32` type), and it must be a
269+
# static parameter (i.e. not a JAX array).
270+
return call(x, eps=np.float32(eps))
271+
272272
273273
# Test that this gives the same result as our reference implementation
274274
x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5))
@@ -334,10 +334,8 @@ def rms_norm_sequential(x, eps=1e-5):
334334
return jex.ffi.ffi_call(
335335
"rms_norm",
336336
jax.ShapeDtypeStruct(x.shape, x.dtype),
337-
x,
338-
eps=np.float32(eps),
339337
vmap_method="sequential",
340-
)
338+
)(x, eps=np.float32(eps))
341339
342340
343341
jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)
@@ -380,10 +378,8 @@ def rms_norm_fwd(x, eps=1e-5):
380378
jax.ShapeDtypeStruct(x.shape, x.dtype),
381379
jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
382380
),
383-
x,
384-
eps=np.float32(eps),
385381
vmap_method="broadcast_fullrank",
386-
)
382+
)(x, eps=np.float32(eps))
387383
return y, (res, x)
388384
389385
@@ -396,11 +392,8 @@ def rms_norm_bwd(eps, res, ct):
396392
jex.ffi.ffi_call(
397393
"rms_norm_bwd",
398394
jax.ShapeDtypeStruct(ct.shape, ct.dtype),
399-
res,
400-
x,
401-
ct,
402-
vmap_method="broadcast_fullrank",
403-
),
395+
vmap_method="broadcast_fullrank",
396+
)(res, x, ct),
404397
)
405398
406399
@@ -477,10 +470,8 @@ def rms_norm_cross_platform(x, eps=1e-5):
477470
return lambda x: jex.ffi.ffi_call(
478471
target_name,
479472
out_type,
480-
x,
481-
eps=np.float32(eps),
482473
vmap_method="broadcast_fullrank",
483-
)
474+
)(x, eps=np.float32(eps))
484475
485476
return jax.lax.platform_dependent(x, cpu=impl("rms_norm"), cuda=impl("rms_norm_cuda"))
486477

examples/ffi/src/jax_ffi_example/attrs.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,11 @@ def array_attr(num: int):
3535
return jex.ffi.ffi_call(
3636
"array_attr",
3737
jax.ShapeDtypeStruct((), np.int32),
38-
array=np.arange(num, dtype=np.int32),
39-
)
38+
)(array=np.arange(num, dtype=np.int32))
4039

4140

4241
def dictionary_attr(**kwargs):
4342
return jex.ffi.ffi_call(
4443
"dictionary_attr",
4544
(jax.ShapeDtypeStruct((), np.int32), jax.ShapeDtypeStruct((), np.int32)),
46-
**kwargs,
47-
)
45+
)(**kwargs)

examples/ffi/src/jax_ffi_example/rms_norm.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,19 +49,17 @@ def rms_norm(x, eps=1e-5):
4949
# same shape and dtype as the input.
5050
out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
5151

52+
# Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for
53+
# the attribute `eps`. Our FFI function expects this to have the C++ `float`
54+
# type (which corresponds to numpy's `float32` type), and it must be a
55+
# static parameter (i.e. not a JAX array).
5256
return jex.ffi.ffi_call(
5357
# The target name must be the same string as we used to register the target
5458
# above in `register_ffi_target`
5559
"rms_norm",
5660
out_type,
57-
x,
58-
# Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for
59-
# the attribute `eps`. Our FFI function expects this to have the C++ `float`
60-
# type (which corresponds to numpy's `float32` type), and it must be a
61-
# static parameter (i.e. not a JAX array).
62-
eps=np.float32(eps),
6361
vmap_method="broadcast_fullrank",
64-
)
62+
)(x, eps=np.float32(eps))
6563

6664

6765
def rms_norm_fwd(x, eps=1e-5):
@@ -71,10 +69,8 @@ def rms_norm_fwd(x, eps=1e-5):
7169
jax.ShapeDtypeStruct(x.shape, x.dtype),
7270
jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
7371
),
74-
x,
75-
eps=np.float32(eps),
7672
vmap_method="broadcast_fullrank",
77-
)
73+
)(x, eps=np.float32(eps))
7874
return y, (res, x)
7975

8076

@@ -87,11 +83,8 @@ def rms_norm_bwd(eps, res, ct):
8783
jex.ffi.ffi_call(
8884
"rms_norm_bwd",
8985
jax.ShapeDtypeStruct(ct.shape, ct.dtype),
90-
res,
91-
x,
92-
ct,
9386
vmap_method="broadcast_fullrank",
94-
),
87+
)(res, x, ct),
9588
)
9689

9790

0 commit comments

Comments
 (0)