
Commit b8d9e7f

Merge pull request jax-ml#27503 from kaixih:enable_doc_scaled_dot
PiperOrigin-RevId: 745322012
2 parents: 7b45552 + 41868ef · commit b8d9e7f

File tree: 4 files changed, +190 -81 lines


docs/jax.nn.rst

Lines changed: 3 additions & 0 deletions

@@ -54,3 +54,6 @@ Other functions
     standardize
     one_hot
     dot_product_attention
+    scaled_matmul
+    get_scaled_dot_general_config
+    scaled_dot_general

jax/_src/nn/functions.py

Lines changed: 181 additions & 63 deletions

@@ -1210,81 +1210,206 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
   return jnp.reshape(out, output_shape)
 
 def scaled_matmul(
-    lhs: Array,
-    rhs: Array,
-    lhs_scales: Array,
-    rhs_scales: Array,
+    a: Array,
+    b: Array,
+    a_scales: Array,
+    b_scales: Array,
     preferred_element_type: DTypeLike = jnp.float32,
 ) -> Array:
-  r"""
-  Performs scaled matrix multiplication between two 3D arrays, with scaling
-  factors applied to the matrices.
-  .. math::
-    \mathrm{ScaledMatmul}(lhs, rhs, lhs_scales, rhs_scales)=lhs_scales \cdot rhs_scales \cdot \mathrm{dot}(lhs, rhs)
+  r"""Scaled matrix multiplication function.
+
+  Performs block-scaled matmul of `a` and `b` using `a_scales` and `b_scales`.
+  The last dim is the contracting dim, and block size is inferred.
+
+  Mathematically, this operation is equivalent to::
+
+    a_block_size = a.shape[-1] // a_scales.shape[-1]
+    b_block_size = b.shape[-1] // b_scales.shape[-1]
+    a_scaled = a * jnp.repeat(a_scales, a_block_size, axis=-1)
+    b_scaled = b * jnp.repeat(b_scales, b_block_size, axis=-1)
+    jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled)
+
   Args:
-      lhs (Array): A 3D array of shape (B, M, K).
-      rhs (Array): A 3D array of shape (B, N, K).
-      lhs_scales (Array): A 3D array of shape (B, M, K_block).
-      rhs_scales (Array): A 3D array of shape (B, N, K_block).
-      preferred_element_type (DTypeLike, optional): The preferred data type
-          for the computation. Defaults to `jnp.float32`.
+    a (Array): Shape (B, M, K).
+    b (Array): Shape (B, N, K).
+    a_scales (Array): Shape (B, M, K_a), where `K % K_a == 0`.
+    b_scales (Array): Shape (B, N, K_b), where `K % K_b == 0`.
+    preferred_element_type (DTypeLike, optional): Defaults to `jnp.float32`.
+
   Returns:
-      Array: A 3D array of shape (B, M, N) representing the scaled matrix
-          multiplication result.
-  Raises:
-      AssertionError: If the number of columns in `lhs` (`lhs_K`) does not
-          match the number of columns in `rhs` (`rhs_K`).
+    Array of shape (B, M, N).
+
   Notes:
-      - The function ensures that the `preferred_element_type` is
-        danonicalized before passing it to the underlying computation.
-      - Scaling is applied to the matrices based on the `lhs_scales` and
-        `rhs_scales` arrays, enabling efficient computations in blocks.
+    - We currently do not support user-defined `precision` for customizing the
+      compute data type. It is fixed to `jnp.float32`.
+    - Block size is inferred as `K // K_a` for `a` and `K // K_b` for `b`.
+    - To use cuDNN with Nvidia Blackwell GPUs, inputs must match::
+
+        # mxfp8
+        a, b: jnp.float8_e4m3fn | jnp.float8_e5m2
+        a_scales, b_scales: jnp.float8_e8m0fnu
+        block_size: 32
+        # nvfp4
+        a, b: jnp.float4_e2m1fn
+        a_scales, b_scales: jnp.float8_e4m3fn
+        block_size: 16
+
+  Examples:
+
+    Basic case:
+
+    >>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3))
+    >>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3))
+    >>> a_scales = jnp.array([0.5]).reshape((1, 1, 1))
+    >>> b_scales = jnp.array([0.5]).reshape((1, 1, 1))
+    >>> scaled_matmul(a, b, a_scales, b_scales)
+    Array([[[8.]]], dtype=float32)
+
+    Using fused cuDNN call on Blackwell GPUs:
+
+    >>> a = random.normal(keys[0], (3, 128, 64), dtype=jnp.float8_e4m3fn)
+    >>> b = random.normal(keys[1], (3, 128, 64), dtype=jnp.float8_e4m3fn)
+    >>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
+    >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
+    >>> scaled_matmul(a, b, a_scales, b_scales)
   """
-  B, M, lhs_K = lhs.shape
-  _, N, rhs_K = rhs.shape
-  assert lhs_K == rhs_K
-  _, _, K_block = lhs_scales.shape
+  if not all(x.ndim == 3 for x in (a, b, a_scales, b_scales)):
+    raise ValueError(
+        "scaled_matmul requires all inputs to be 3-dimensional arrays"
+    )
+
+  B_a, M_a, K_a = a.shape
+  B_b, N_b, K_b = b.shape
+  if K_a != K_b or B_a != B_b:
+    raise ValueError(
+        "scaled_matmul requires inputs a and b to have matching batch (B) "
+        f"and contract (K) dimensions, but got shapes {a.shape} and "
+        f"{b.shape}"
+    )
+
+  B_as, M_as, K_as = a_scales.shape
+  B_bs, N_bs, K_bs = b_scales.shape
+  if K_as != K_bs or B_as != B_bs:
+    raise ValueError(
+        "scaled_matmul requires scales to have matching batch (B) and "
+        f"contract (K) dimensions, but got shapes {a_scales.shape} and "
+        f"{b_scales.shape}"
+    )
+
+  if M_as != M_a or N_bs != N_b:
+    raise ValueError(
+        "scaled_matmul requires scales to match non-contract dimensions of "
+        f"inputs, but got shapes a: {a.shape}, b: {b.shape}, a_scales: "
+        f"{a_scales.shape}, b_scales: {b_scales.shape}"
+    )
 
   preferred_element_type = dtypes.canonicalize_dtype(
       np.dtype(preferred_element_type)
   )
   out = cudnn_scaled_matmul(
-      lhs,
-      rhs,
-      lhs_scales,
-      rhs_scales,
+      a,
+      b,
+      a_scales,
+      b_scales,
       preferred_element_type=preferred_element_type,
   )
   return out
 
+def get_scaled_dot_general_config(mode: Literal['nvfp4', 'mxfp8'],
+                                  global_scale: Array | None = None):
+  r"""Get quantization configs for scaled_dot_general.
+
+  Create quantization configs for the `jax.nn.scaled_dot_general`.
+
+  See Also:
+    - :func:`jax.nn.scaled_dot_general`: Scaled dot general function.
+  """
+
+  if mode == 'nvfp4':
+    one = jnp.ones((1,), dtype=jnp.float32)
+    return BlockScaleConfig(
+        mode='nvfp4',
+        block_size=16,
+        data_type=jnp.float4_e2m1fn,
+        scale_type=jnp.float8_e4m3fn,
+        global_scale=one if global_scale is None else global_scale,
+        infer_only=False
+    )
+  elif mode == 'mxfp8':
+    return BlockScaleConfig(
+        mode='mxfp8',
+        block_size=32,
+        data_type=jnp.float8_e4m3fn,
+        scale_type=jnp.float8_e8m0fnu,
+        global_scale=None,
+        infer_only=False
+    )
+  else:
+    raise ValueError(f"Unsupported mode: {mode}")
+
 def scaled_dot_general(
     lhs, rhs,
     dimension_numbers,
     preferred_element_type=jnp.float32,
     configs: List[BlockScaleConfig] | None = None,
-    implementation: Literal['cudnn'] | None = None,
 ):
   r"""Scaled dot general operation.
-  Computes the scaled dot general on lhs, rhs with quanitzation specified by configs:
-  .. math::
-    \widehat{lhs}, s_a=\mathrm{quantize}(lhs) \\
-    \widehat{rhs}, s_b=\mathrm{quantize}(rhs) \\
-    \mathrm{ScaledDot}(lhs, rhs)=s_a \cdot s_b \cdot \mathrm{dot}(\widehat{lhs}, \widehat{rhs})
+
+  Performs a generalized dot product with block-scaled quantization on the
+  lhs and rhs inputs. This operation extends `lax.dot_general` to support
+  user-defined scaling configurations.
+
+  Essentially, the operation follows::
+
+    a, a_scales = quantize(lhs, configs[0])
+    b, b_scales = quantize(rhs, configs[1])
+    c = jax.nn.scaled_matmul(a, b, a_scales, b_scales)
+
   Args:
-    lhs: Left-hand side input tensor.
-    rhs: Right-hand side input tensor.
-    dimension_numbers: A tuple specifying the contraction and batch dimensions
-      for the dot general operation. Must follow the format:
-      `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.
-    preferred_element_type: The preferred output data type. Supported types are
-      `jnp.float32`, `jnp.bfloat16`, and `jnp.float16`. Defaults to `jnp.float32`.
-    configs: A list of `BlockScaleConfig` specifying the scaling
-      configurations for the operation. Defaults to `mxfp8`.
-    implementation: A string to control which implementation backend to use.
-      Supported strings are `cudnn` (cuDNN block scaled dot). It defaults
-      to `None`, which will automatically select the best available backend.
+    lhs (ArrayLike): Input array.
+    rhs (ArrayLike): Input array.
+    dimension_numbers (DotDimensionNumbers): A tuple of two tuples specifying
+      the contraction and batch dimensions:
+      `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.
+    preferred_element_type (DTypeLike, optional): Output data type of the dot
+      product. Defaults to `jnp.float32`. Other valid types include
+      `jnp.bfloat16` and `jnp.float16`.
+    configs (list of BlockScaleConfig, optional): Scaling configurations for
+      lhs, rhs, and gradients. Users can obtain valid configurations via
+      `jax.nn.get_scaled_dot_general_config`. Currently, `nvfp4` and `mxfp8`
+      are supported. If `None`, falls back to `lax.dot_general`.
+
   Returns:
-    The result of the scaled dot general operation.
+    Array: The resulting tensor, with batch dimensions first, followed by
+    non-contracting/non-batch dimensions of lhs, and then those of rhs.
+
+  See Also:
+    - :func:`jax.nn.scaled_matmul`: Scaled matmul function.
+    - :func:`jax.lax.dot_general`: General dot product operator.
+
+  Notes:
+    - Unlike `nn.scaled_matmul`, which assumes quantized low-precision
+      inputs with explicit scaling factors, this operator takes high-precision
+      inputs, applies quantization internally, and handles the backward pass.
+
+  Examples:
+
+    Creating config for mxfp8:
+
+    >>> configs = [jax.nn.get_scaled_dot_general_config('mxfp8')] * 3
+
+    Creating config for nvfp4:
+
+    >>> global_scale = jnp.array([0.5], jnp.float32)
+    >>> configs = [jax.nn.get_scaled_dot_general_config('nvfp4', global_scale)] * 3
+
+    Using scaled_dot_general with the configs:
+
+    >>> import functools
+    >>> scaled_dot_general_fn = functools.partial(jax.nn.scaled_dot_general, configs=configs)
+    >>> lhs = random.normal(keys[0], (3, 128, 64))
+    >>> rhs = random.normal(keys[1], (3, 128, 64))
+    >>> out = scaled_dot_general_fn(lhs, rhs, (((2,), (2,)), ((0,), (0,))))
   """
   # Create configs if not provided
   if configs is None:
@@ -1300,17 +1425,10 @@ def scaled_dot_general(
     )
     configs = [mxfp8_config for _ in range(3)]
 
-  if implementation is None:
-    implementation = 'cudnn'
-
-  match implementation:
-    case 'cudnn':
-      out = cudnn_scaled_dot_general(
-          lhs, rhs, dimension_numbers,
-          preferred_element_type=preferred_element_type,
-          configs=configs
-      )
-    case _:
-      raise ValueError(f"Unsupported implementation option: {implementation}")
+  out = cudnn_scaled_dot_general(
+      lhs, rhs, dimension_numbers,
+      preferred_element_type=preferred_element_type,
+      configs=configs
+  )
 
   return out
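
For reference, the block-scaling semantics spelled out in the new `scaled_matmul` docstring can be reproduced with plain `jnp` ops. The sketch below is illustrative only (the helper name `reference_scaled_matmul` is not part of the change); it runs in float32 on any backend and matches the docstring's basic example:

    import jax.numpy as jnp

    def reference_scaled_matmul(a, b, a_scales, b_scales):
      # Each scale entry covers one contiguous block of the contracting
      # (last) dimension; block sizes are inferred from the shapes.
      a_block = a.shape[-1] // a_scales.shape[-1]
      b_block = b.shape[-1] // b_scales.shape[-1]
      a_scaled = a * jnp.repeat(a_scales, a_block, axis=-1)
      b_scaled = b * jnp.repeat(b_scales, b_block, axis=-1)
      # (B, M, K) x (B, N, K) -> (B, M, N); both operands are K-last.
      return jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled)

    a = jnp.array([1., 2., 3.]).reshape((1, 1, 3))
    b = jnp.array([4., 5., 6.]).reshape((1, 1, 3))
    a_scales = jnp.array([0.5]).reshape((1, 1, 1))
    b_scales = jnp.array([0.5]).reshape((1, 1, 1))
    print(reference_scaled_matmul(a, b, a_scales, b_scales))  # [[[8.]]]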
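
The `scaled_dot_general` docstring notes that, unlike `scaled_matmul`, it takes high-precision inputs, quantizes them internally per the supplied configs, and handles the backward pass. A minimal training-style sketch, assuming a Blackwell-class GPU with a sufficiently recent cuDNN (on other setups the underlying cuDNN call may not be supported):

    import functools
    import jax
    import jax.numpy as jnp

    # One config each for lhs, rhs, and gradients, per the docstring.
    configs = [jax.nn.get_scaled_dot_general_config('mxfp8')] * 3
    scaled_dot = functools.partial(jax.nn.scaled_dot_general, configs=configs)

    def loss(lhs, rhs):
      # Contract the last dimension of both operands, batch over dim 0,
      # mirroring the docstring example's dimension_numbers.
      out = scaled_dot(lhs, rhs, (((2,), (2,)), ((0,), (0,))))
      return jnp.sum(out)

    k1, k2 = jax.random.split(jax.random.key(0))
    lhs = jax.random.normal(k1, (3, 128, 64))
    rhs = jax.random.normal(k2, (3, 128, 64))
    grads = jax.grad(loss, argnums=(0, 1))(lhs, rhs)  # quantized fwd and bwd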

jax/nn/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -38,6 +38,7 @@
   identity as identity,
   relu6 as relu6,
   dot_product_attention as dot_product_attention,
+  get_scaled_dot_general_config as get_scaled_dot_general_config,
   scaled_dot_general as scaled_dot_general,
   scaled_matmul as scaled_matmul,
   selu as selu,

tests/nn_test.py

Lines changed: 5 additions & 18 deletions

@@ -31,7 +31,6 @@
 from jax._src.cudnn.scaled_matmul_stablehlo import (
     quantize,
     shape_normalization,
-    BlockScaleConfig,
 )
 from jax.test_util import check_grads
 from jax import nn
@@ -110,17 +109,7 @@ def create_mxfp8_configs_if_available():
   if _dtypes.float8_e8m0fnu is None:
     raise unittest.SkipTest("float8_e8m0fnu is not available.")
 
-  def _create_mxfp8_config():
-    return BlockScaleConfig(
-        mode='mxfp8',
-        block_size=32,
-        data_type=jnp.float8_e4m3fn,
-        scale_type=jnp.float8_e8m0fnu,
-        global_scale=None,
-        infer_only=False
-    )
-
-  return [_create_mxfp8_config() for _ in range(3)]
+  return [nn.get_scaled_dot_general_config("mxfp8") for _ in range(3)]
 
 
 @jtu.with_config(jax_legacy_prng_key="allow",
@@ -130,10 +119,9 @@ class NNFunctionsTest(jtu.JaxTestCase):
       contract=[160, 96],
       lhs_non_contract=[240, 100],
       dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
-      impl=['cudnn',],
   )
-  def testScaledMatmul(self, contract, lhs_non_contract, dtype, impl):
-    if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700):
+  def testScaledMatmul(self, contract, lhs_non_contract, dtype):
+    if not _is_required_cudnn_version_satisfied("10.0", 90700):
       raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")
     # Check if float8_e8m0fnu is available
     configs = create_mxfp8_configs_if_available()
@@ -153,11 +141,10 @@ def testScaledMatmul(self, contract, lhs_non_contract, dtype, impl):
   @parameterized.product(
       is_training=[True, False],
       output_type=[jnp.float16, jnp.bfloat16, jnp.float32],
-      impl=['cudnn',],
   )
   def testScaledDotGeneral(
-      self, is_training, output_type, impl):
-    if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700):
+      self, is_training, output_type):
+    if not _is_required_cudnn_version_satisfied("10.0", 90700):
       raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")
 
     configs = create_mxfp8_configs_if_available()
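
The deleted `_create_mxfp8_config` helper built a `BlockScaleConfig` with exactly the fields that `nn.get_scaled_dot_general_config('mxfp8')` now returns, so the test simplification is behavior-preserving. A small sketch of that equivalence, assuming this build exposes `jnp.float8_e8m0fnu` and that `BlockScaleConfig` exposes its constructor arguments as attributes:

    import jax.numpy as jnp
    from jax import nn

    cfg = nn.get_scaled_dot_general_config('mxfp8')
    # Same values the removed test helper passed to BlockScaleConfig.
    assert cfg.mode == 'mxfp8'
    assert cfg.block_size == 32
    assert cfg.data_type == jnp.float8_e4m3fn
    assert cfg.scale_type == jnp.float8_e8m0fnu
    assert cfg.global_scale is None
    assert cfg.infer_only is False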
