Commit 41868ef
Message: format
Parent: 5ddec65

1 file changed: +19 −18

jax/_src/nn/functions.py: 19 additions, 18 deletions
Apart from one added `>>> import functools` line in the `scaled_dot_general` doctest, every removed line reappears unchanged as an added line: the changes are whitespace-only, consistent with the commit message "format".

```diff
@@ -1218,11 +1218,11 @@ def scaled_matmul(
 ) -> Array:
   r"""Scaled matrix multiplication function.
 
-  Performs block-scaled matmul of `a` and `b` using `a_scales` and `b_scales`.
+  Performs block-scaled matmul of `a` and `b` using `a_scales` and `b_scales`.
   The last dim is the contracting dim, and block size is inferred.
 
   Mathematically, this operation is equivalent to::
-
+
     a_block_size = a.shape[-1] // a_scales.shape[-1]
     b_block_size = b.shape[-1] // b_scales.shape[-1]
     a_scaled = a * jnp.repeat(a_scales, a_block_size, axis=-1)
```
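The equivalence spelled out in this docstring is easy to check outside the library. Below is a minimal pure-JAX sketch of that reference semantics; `reference_scaled_matmul` is a hypothetical name, not part of `jax.nn`, and the batched contraction over the last axis is written as an einsum:

```python
import jax.numpy as jnp

def reference_scaled_matmul(a, b, a_scales, b_scales):
  # Each scale covers a contiguous block of the contracting (last) dim.
  a_block_size = a.shape[-1] // a_scales.shape[-1]
  b_block_size = b.shape[-1] // b_scales.shape[-1]
  # Expand per-block scales to per-element scales, as in the docstring.
  a_scaled = a * jnp.repeat(a_scales, a_block_size, axis=-1)
  b_scaled = b * jnp.repeat(b_scales, b_block_size, axis=-1)
  # Batched matmul with the last dim of both operands contracting:
  # (B, M, K) x (B, N, K) -> (B, M, N).
  return jnp.einsum('bmk,bnk->bmn', a_scaled, b_scaled)

a = jnp.array([1., 2., 3.]).reshape((1, 1, 3))
b = jnp.array([4., 5., 6.]).reshape((1, 1, 3))
a_scales = jnp.array([0.5]).reshape((1, 1, 1))
b_scales = jnp.array([0.5]).reshape((1, 1, 1))
print(reference_scaled_matmul(a, b, a_scales, b_scales))  # [[[8.]]]
```

On the basic case from the doctest this reproduces `Array([[[8.]]], dtype=float32)`: scaling both operands by 0.5 multiplies the raw dot product 1*4 + 2*5 + 3*6 = 32 by 0.25.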
```diff
@@ -1258,26 +1258,26 @@ def scaled_matmul(
 
   Basic case:
 
-  >>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3))
-  >>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3))
-  >>> a_scales = jnp.array([0.5]).reshape((1, 1, 1))
-  >>> b_scales = jnp.array([0.5]).reshape((1, 1, 1))
-  >>> scaled_matmul(a, b, a_scales, b_scales)
-  Array([[[8.]]], dtype=float32)
-
+  >>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3))
+  >>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3))
+  >>> a_scales = jnp.array([0.5]).reshape((1, 1, 1))
+  >>> b_scales = jnp.array([0.5]).reshape((1, 1, 1))
+  >>> scaled_matmul(a, b, a_scales, b_scales)
+  Array([[[8.]]], dtype=float32)
+
   Using fused cuDNN call on Blackwell GPUs:
 
-  >>> a = random.normal(keys[0], (3, 128, 64), dtype=jnp.float8_e4m3fn)
-  >>> b = random.normal(keys[1], (3, 128, 64), dtype=jnp.float8_e4m3fn)
-  >>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
-  >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
-  >>> scaled_matmul(a, b, a_scales, b_scales)
+  >>> a = random.normal(keys[0], (3, 128, 64), dtype=jnp.float8_e4m3fn)
+  >>> b = random.normal(keys[1], (3, 128, 64), dtype=jnp.float8_e4m3fn)
+  >>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
+  >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
+  >>> scaled_matmul(a, b, a_scales, b_scales)
   """
   if not all(x.ndim == 3 for x in (a, b, a_scales, b_scales)):
     raise ValueError(
         "scaled_matmul requires all inputs to be 3-dimensional arrays"
     )
-
+
   B_a, M_a, K_a = a.shape
   B_b, N_b, K_b = b.shape
   if K_a != K_b or B_a != B_b:
```
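For the cuDNN example above, the block size is inferred from the shapes alone. A quick sanity check (shapes taken from the doctest; nothing here is library API):

```python
# Shapes from the Blackwell doctest: data (3, 128, 64), scales (3, 128, 4).
K = 64        # contracting dim of a and b
K_scales = 4  # last dim of a_scales and b_scales
block_size = K // K_scales
assert block_size == 16  # each scale covers 16 contiguous elements along K
```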
```diff
@@ -1286,7 +1286,7 @@ def scaled_matmul(
         f"and contract (K) dimensions, but got shapes {a.shape} and "
         f"{b.shape}"
     )
-
+
   B_as, M_as, K_as = a_scales.shape
   B_bs, N_bs, K_bs = b_scales.shape
   if K_as != K_bs or B_as != B_bs:
```
```diff
@@ -1295,7 +1295,7 @@ def scaled_matmul(
         f"contract (K) dimensions, but got shapes {a_scales.shape} and "
         f"{b_scales.shape}"
     )
-
+
   if M_as != M_a or N_bs != N_b:
     raise ValueError(
         "scaled_matmul requires scales to match non-contract dimensions of "
```
```diff
@@ -1378,7 +1378,7 @@ def scaled_dot_general(
       lhs, rhs, and gradients. Users can obtain valid configurations via
       `jax.nn.get_scaled_dot_general_config`. Currently, `nvfp4` and `mxfp8`
       are supported. If `None`, falls back to `lax.dot_general`.
-
+
   Returns:
     Array: The resulting tensor, with batch dimensions first, followed by
     non-contracting/non-batch dimensions of lhs, and then those of rhs.
```
```diff
@@ -1405,6 +1405,7 @@ def scaled_dot_general(
 
   Using scaled_dot_general with the configs:
 
+  >>> import functools
   >>> scaled_dot_general_fn = functools.partial(jax.nn.scaled_dot_general, configs=configs)
   >>> lhs = random.normal(keys[0], (3, 128, 64))
   >>> rhs = random.normal(keys[1], (3, 128, 64))
```
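Putting the amended doctest together end to end. This is a hedged sketch, not the library's canonical example: it assumes `jax.nn.get_scaled_dot_general_config` accepts the mode name (`'mxfp8'` or `'nvfp4'`, per the docstring above), that `scaled_dot_general` follows `lax.dot_general`'s dimension-numbers convention, and the key setup (`keys`) is invented to make the snippet self-contained:

```python
import functools
import jax
from jax import random

# Assumption: the config helper accepts the mode name named in the docstring.
configs = jax.nn.get_scaled_dot_general_config('mxfp8')
scaled_dot_general_fn = functools.partial(jax.nn.scaled_dot_general,
                                          configs=configs)

keys = random.split(random.key(0), 2)
lhs = random.normal(keys[0], (3, 128, 64))
rhs = random.normal(keys[1], (3, 128, 64))

# Assumption: dimension numbers follow lax.dot_general's convention,
# contracting the last axis of both operands and batching over the first.
out = scaled_dot_general_fn(lhs, rhs, (((2,), (2,)), ((0,), (0,))))
print(out.shape)  # (3, 128, 128): batch, then lhs rows, then rhs rows
```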
