@@ -1210,81 +1210,206 @@ def _check_shape_and_dtype(t: Array | None, shape: Sequence[int],
   return jnp.reshape(out, output_shape)

 def scaled_matmul(
-    lhs: Array,
-    rhs: Array,
-    lhs_scales: Array,
-    rhs_scales: Array,
+    a: Array,
+    b: Array,
+    a_scales: Array,
+    b_scales: Array,
     preferred_element_type: DTypeLike = jnp.float32,
 ) -> Array:
-  r"""
-  Performs scaled matrix multiplication between two 3D arrays, with scaling
-  factors applied to the matrices.
-  .. math::
-    \mathrm{ScaledMatmul}(lhs, rhs, lhs_scales, rhs_scales)=lhs_scales \cdot rhs_scales \cdot \mathrm{dot}(lhs, rhs)
+  r"""Scaled matrix multiplication function.
+
+  Performs block-scaled matmul of `a` and `b` using `a_scales` and `b_scales`.
+  The last dim is the contracting dim, and block size is inferred.
+
+  Mathematically, this operation is equivalent to::
+
+    a_block_size = a.shape[-1] // a_scales.shape[-1]
+    b_block_size = b.shape[-1] // b_scales.shape[-1]
+    a_scaled = a * jnp.repeat(a_scales, a_block_size, axis=-1)
+    b_scaled = b * jnp.repeat(b_scales, b_block_size, axis=-1)
+    jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled)
+
   Args:
-    lhs (Array): A 3D array of shape (B, M, K).
-    rhs (Array): A 3D array of shape (B, N, K).
-    lhs_scales (Array): A 3D array of shape (B, M, K_block).
-    rhs_scales (Array): A 3D array of shape (B, N, K_block).
-    preferred_element_type (DTypeLike, optional): The preferred data type
-      for the computation. Defaults to `jnp.float32`.
+    a (Array): Shape (B, M, K).
+    b (Array): Shape (B, N, K).
+    a_scales (Array): Shape (B, M, K_a), where `K % K_a == 0`.
+    b_scales (Array): Shape (B, N, K_b), where `K % K_b == 0`.
+    preferred_element_type (DTypeLike, optional): Defaults to `jnp.float32`.
+
   Returns:
-    Array: A 3D array of shape (B, M, N) representing the scaled matrix
-      multiplication result.
-  Raises:
-    AssertionError: If the number of columns in `lhs` (`lhs_K`) does not
-      match the number of columns in `rhs` (`rhs_K`).
+    Array of shape (B, M, N).
+
   Notes:
-    - The function ensures that the `preferred_element_type` is
-      canonicalized before passing it to the underlying computation.
-    - Scaling is applied to the matrices based on the `lhs_scales` and
-      `rhs_scales` arrays, enabling efficient computations in blocks.
+    - We currently do not support user-defined `precision` for customizing the
+      compute data type. It is fixed to `jnp.float32`.
+    - Block size is inferred as `K // K_a` for `a` and `K // K_b` for `b`.
+    - To use cuDNN with Nvidia Blackwell GPUs, inputs must match::
+
+        # mxfp8
+        a, b: jnp.float8_e4m3fn | jnp.float8_e5m2
+        a_scales, b_scales: jnp.float8_e8m0fnu
+        block_size: 32
+        # nvfp4
+        a, b: jnp.float4_e2m1fn
+        a_scales, b_scales: jnp.float8_e4m3fn
+        block_size: 16
+
+  Examples:
+
+    Basic case:
+
+    >>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3))
+    >>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3))
+    >>> a_scales = jnp.array([0.5]).reshape((1, 1, 1))
+    >>> b_scales = jnp.array([0.5]).reshape((1, 1, 1))
+    >>> scaled_matmul(a, b, a_scales, b_scales)
+    Array([[[8.]]], dtype=float32)
+
+    Using fused cuDNN call on Blackwell GPUs:
+
+    >>> a = random.normal(keys[0], (3, 128, 64), dtype=jnp.float8_e4m3fn)
+    >>> b = random.normal(keys[1], (3, 128, 64), dtype=jnp.float8_e4m3fn)
+    >>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
+    >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
+    >>> scaled_matmul(a, b, a_scales, b_scales)
   """
-  B, M, lhs_K = lhs.shape
-  _, N, rhs_K = rhs.shape
-  assert lhs_K == rhs_K
-  _, _, K_block = lhs_scales.shape
+  if not all(x.ndim == 3 for x in (a, b, a_scales, b_scales)):
+    raise ValueError(
+        "scaled_matmul requires all inputs to be 3-dimensional arrays"
+    )
+
+  B_a, M_a, K_a = a.shape
+  B_b, N_b, K_b = b.shape
+  if K_a != K_b or B_a != B_b:
+    raise ValueError(
+        "scaled_matmul requires inputs a and b to have matching batch (B) "
+        f"and contract (K) dimensions, but got shapes {a.shape} and "
+        f"{b.shape}"
+    )
+
+  B_as, M_as, K_as = a_scales.shape
+  B_bs, N_bs, K_bs = b_scales.shape
+  if K_as != K_bs or B_as != B_bs:
+    raise ValueError(
+        "scaled_matmul requires scales to have matching batch (B) and "
+        f"contract (K) dimensions, but got shapes {a_scales.shape} and "
+        f"{b_scales.shape}"
+    )
+
+  if M_as != M_a or N_bs != N_b:
+    raise ValueError(
+        "scaled_matmul requires scales to match non-contract dimensions of "
+        f"inputs, but got shapes a: {a.shape}, b: {b.shape}, a_scales: "
+        f"{a_scales.shape}, b_scales: {b_scales.shape}"
+    )

   preferred_element_type = dtypes.canonicalize_dtype(
       np.dtype(preferred_element_type)
   )
   out = cudnn_scaled_matmul(
-      lhs,
-      rhs,
-      lhs_scales,
-      rhs_scales,
+      a,
+      b,
+      a_scales,
+      b_scales,
       preferred_element_type=preferred_element_type,
   )
   return out

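The `jnp.repeat`/`einsum` description added to the docstring above can be checked directly. Below is a minimal sketch of that reference semantics in plain float32, with `reference_scaled_matmul` being a helper name introduced here purely for illustration (it is not part of this diff):

    import jax.numpy as jnp

    def reference_scaled_matmul(a, b, a_scales, b_scales):
      # Broadcast each per-block scale across its contraction block,
      # then contract over the last dim, batching over the first dim.
      a_block = a.shape[-1] // a_scales.shape[-1]
      b_block = b.shape[-1] // b_scales.shape[-1]
      a_scaled = a * jnp.repeat(a_scales, a_block, axis=-1)
      b_scaled = b * jnp.repeat(b_scales, b_block, axis=-1)
      return jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled)

    # Matches the "Basic case" doctest: 0.5 * 0.5 * (1*4 + 2*5 + 3*6) = 8.
    a = jnp.array([1., 2., 3.]).reshape((1, 1, 3))
    b = jnp.array([4., 5., 6.]).reshape((1, 1, 3))
    scales = jnp.array([0.5]).reshape((1, 1, 1))
    print(reference_scaled_matmul(a, b, scales, scales))  # [[[8.]]]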
+def get_scaled_dot_general_config(mode: Literal['nvfp4', 'mxfp8'],
+                                  global_scale: Array | None = None):
+  r"""Get quantization configs for scaled_dot_general.
+
+  Create quantization configs for the `jax.nn.scaled_dot_general`.
+
+  See Also:
+    - :func:`jax.nn.scaled_dot_general`: Scaled dot general function.
+  """
+
+  if mode == 'nvfp4':
+    one = jnp.ones((1,), dtype=jnp.float32)
+    return BlockScaleConfig(
+        mode='nvfp4',
+        block_size=16,
+        data_type=jnp.float4_e2m1fn,
+        scale_type=jnp.float8_e4m3fn,
+        global_scale=one if global_scale is None else global_scale,
+        infer_only=False
+    )
+  elif mode == 'mxfp8':
+    return BlockScaleConfig(
+        mode='mxfp8',
+        block_size=32,
+        data_type=jnp.float8_e4m3fn,
+        scale_type=jnp.float8_e8m0fnu,
+        global_scale=None,
+        infer_only=False
+    )
+  else:
+    raise ValueError(f"Unsupported mode: {mode}")
+
 def scaled_dot_general(
     lhs, rhs,
     dimension_numbers,
     preferred_element_type=jnp.float32,
     configs: List[BlockScaleConfig] | None = None,
-    implementation: Literal['cudnn'] | None = None,
     ):
   r"""Scaled dot general operation.
-  Computes the scaled dot general on lhs, rhs with quantization specified by configs:
-  .. math::
-    \widehat{lhs}, s_a=\mathrm{quantize}(lhs) \\
-    \widehat{rhs}, s_b=\mathrm{quantize}(rhs) \\
-    \mathrm{ScaledDot}(lhs, rhs)=s_a \cdot s_b \cdot \mathrm{dot}(\widehat{lhs}, \widehat{rhs})
+
+  Performs a generalized dot product with block-scaled quantization on the
+  lhs and rhs inputs. This operation extends `lax.dot_general` to support
+  user-defined scaling configurations.
+
+  Essentially, the operation follows::
+
+    a, a_scales = quantize(lhs, configs[0])
+    b, b_scales = quantize(rhs, configs[1])
+    c = jax.nn.scaled_matmul(a, b, a_scales, b_scales)
+
   Args:
-    lhs: Left-hand side input tensor.
-    rhs: Right-hand side input tensor.
-    dimension_numbers: A tuple specifying the contraction and batch dimensions
-      for the dot general operation. Must follow the format:
-      `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.
-    preferred_element_type: The preferred output data type. Supported types are
-      `jnp.float32`, `jnp.bfloat16`, and `jnp.float16`. Defaults to `jnp.float32`.
-    configs: A list of `BlockScaleConfig` specifying the scaling
-      configurations for the operation. Defaults to `mxfp8`.
-    implementation: A string to control which implementation backend to use.
-      Supported strings are `cudnn` (cuDNN block scaled dot). It defaults
-      to `None`, which will automatically select the best available backend.
+    lhs (ArrayLike): Input array.
+    rhs (ArrayLike): Input array.
+    dimension_numbers (DotDimensionNumbers): A tuple of two tuples specifying
+      the contraction and batch dimensions:
+      `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.
+    preferred_element_type (DTypeLike, optional): Output data type of the dot
+      product. Defaults to `jnp.float32`. Other valid types include
+      `jnp.bfloat16` and `jnp.float16`.
+    configs (list of BlockScaleConfig, optional): Scaling configurations for
+      lhs, rhs, and gradients. Users can obtain valid configurations via
+      `jax.nn.get_scaled_dot_general_config`. Currently, `nvfp4` and `mxfp8`
+      are supported. If `None`, falls back to `lax.dot_general`.
+
   Returns:
-    The result of the scaled dot general operation.
+    Array: The resulting tensor, with batch dimensions first, followed by
+      non-contracting/non-batch dimensions of lhs, and then those of rhs.
+
+  See Also:
+    - :func:`jax.nn.scaled_matmul`: Scaled matmul function.
+    - :func:`jax.lax.dot_general`: General dot product operator.
+
+  Notes:
+    - Unlike `nn.scaled_matmul`, which assumes quantized low-precision
+      inputs with explicit scaling factors, this operator takes high-precision
+      inputs, applies quantization internally, and handles the backward pass.
+
+  Examples:
+
+    Creating config for mxfp8:
+
+    >>> configs = [jax.nn.get_scaled_dot_general_config('mxfp8')] * 3
+
+    Creating config for nvfp4:
+
+    >>> global_scale = jnp.array([0.5], jnp.float32)
+    >>> configs = [jax.nn.get_scaled_dot_general_config('nvfp4', global_scale)] * 3
+
+    Using scaled_dot_general with the configs:
+
+    >>> import functools
+    >>> scaled_dot_general_fn = functools.partial(jax.nn.scaled_dot_general, configs=configs)
+    >>> lhs = random.normal(keys[0], (3, 128, 64))
+    >>> rhs = random.normal(keys[1], (3, 128, 64))
+    >>> out = scaled_dot_general_fn(lhs, rhs, (((2,), (2,)), ((0,), (0,))))
   """
   # Create configs if not provided
   if configs is None:
@@ -1300,17 +1425,10 @@ def scaled_dot_general(
     )
     configs = [mxfp8_config for _ in range(3)]

-  if implementation is None:
-    implementation = 'cudnn'
-
-  match implementation:
-    case 'cudnn':
-      out = cudnn_scaled_dot_general(
-          lhs, rhs, dimension_numbers,
-          preferred_element_type=preferred_element_type,
-          configs=configs
-      )
-    case _:
-      raise ValueError(f"Unsupported implementation option: {implementation}")
+  out = cudnn_scaled_dot_general(
+      lhs, rhs, dimension_numbers,
+      preferred_element_type=preferred_element_type,
+      configs=configs
+  )

   return out
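Taken together, the renamed `scaled_matmul` arguments, the new `get_scaled_dot_general_config` helper, and the simplified `scaled_dot_general` are used as in the docstring examples above. A self-contained sketch, assuming a JAX build that exports these functions from `jax.nn` and a backend where the cuDNN block-scaled path is available:

    import functools
    import jax
    from jax import random

    # Three configs: one each for lhs, rhs, and the gradient, per the docstring.
    configs = [jax.nn.get_scaled_dot_general_config('mxfp8')] * 3
    scaled_dot = functools.partial(jax.nn.scaled_dot_general, configs=configs)

    keys = random.split(random.key(0), 2)
    lhs = random.normal(keys[0], (3, 128, 64))
    rhs = random.normal(keys[1], (3, 128, 64))
    # Contract dim 2 of both operands, batch over dim 0 -> output (3, 128, 128).
    out = scaled_dot(lhs, rhs, (((2,), (2,)), ((0,), (0,))))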