@@ -1218,11 +1218,11 @@ def scaled_matmul(
 ) -> Array:
   r"""Scaled matrix multiplication function.
 
-  Performs block-scaled matmul of `a` and `b` using `a_scales` and `b_scales`.
+  Performs block-scaled matmul of `a` and `b` using `a_scales` and `b_scales`.
   The last dim is the contracting dim, and block size is inferred.
 
   Mathematically, this operation is equivalent to::
-
+
     a_block_size = a.shape[-1] // a_scales.shape[-1]
     b_block_size = b.shape[-1] // b_scales.shape[-1]
     a_scaled = a * jnp.repeat(a_scales, a_block_size, axis=-1)
@@ -1258,26 +1258,26 @@ def scaled_matmul(
 
     Basic case:
 
-    >>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3))
-    >>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3))
-    >>> a_scales = jnp.array([0.5]).reshape((1, 1, 1))
-    >>> b_scales = jnp.array([0.5]).reshape((1, 1, 1))
-    >>> scaled_matmul(a, b, a_scales, b_scales)
-    Array([[[8.]]], dtype=float32)
-
+    >>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3))
+    >>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3))
+    >>> a_scales = jnp.array([0.5]).reshape((1, 1, 1))
+    >>> b_scales = jnp.array([0.5]).reshape((1, 1, 1))
+    >>> scaled_matmul(a, b, a_scales, b_scales)
+    Array([[[8.]]], dtype=float32)
+
     Using fused cuDNN call on Blackwell GPUs:
 
-    >>> a = random.normal(keys[0], (3, 128, 64), dtype=jnp.float8_e4m3fn)
-    >>> b = random.normal(keys[1], (3, 128, 64), dtype=jnp.float8_e4m3fn)
-    >>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
-    >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
-    >>> scaled_matmul(a, b, a_scales, b_scales)
+    >>> a = random.normal(keys[0], (3, 128, 64), dtype=jnp.float8_e4m3fn)
+    >>> b = random.normal(keys[1], (3, 128, 64), dtype=jnp.float8_e4m3fn)
+    >>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
+    >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
+    >>> scaled_matmul(a, b, a_scales, b_scales)
   """
   if not all(x.ndim == 3 for x in (a, b, a_scales, b_scales)):
     raise ValueError(
         "scaled_matmul requires all inputs to be 3-dimensional arrays"
     )
-
+
   B_a, M_a, K_a = a.shape
   B_b, N_b, K_b = b.shape
   if K_a != K_b or B_a != B_b:
@@ -1286,7 +1286,7 @@ def scaled_matmul(
         f"and contract (K) dimensions, but got shapes {a.shape} and "
         f"{b.shape}"
     )
-
+
   B_as, M_as, K_as = a_scales.shape
   B_bs, N_bs, K_bs = b_scales.shape
   if K_as != K_bs or B_as != B_bs:
@@ -1295,7 +1295,7 @@ def scaled_matmul(
         f"contract (K) dimensions, but got shapes {a_scales.shape} and "
         f"{b_scales.shape}"
     )
-
+
   if M_as != M_a or N_bs != N_b:
     raise ValueError(
         "scaled_matmul requires scales to match non-contract dimensions of "
@@ -1378,7 +1378,7 @@ def scaled_dot_general(
       lhs, rhs, and gradients. Users can obtain valid configurations via
       `jax.nn.get_scaled_dot_general_config`. Currently, `nvfp4` and `mxfp8`
       are supported. If `None`, falls back to `lax.dot_general`.
-
+
   Returns:
     Array: The resulting tensor, with batch dimensions first, followed by
     non-contracting/non-batch dimensions of lhs, and then those of rhs.
@@ -1405,6 +1405,7 @@ def scaled_dot_general(
 
     Using scaled_dot_general with the configs:
 
+    >>> import functools
     >>> scaled_dot_general_fn = functools.partial(jax.nn.scaled_dot_general, configs=configs)
     >>> lhs = random.normal(keys[0], (3, 128, 64))
     >>> rhs = random.normal(keys[1], (3, 128, 64))