@@ -1273,14 +1273,35 @@ def scaled_matmul(
12731273 >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
12741274 >>> scaled_matmul(a, b, a_scales, b_scales)
12751275 """
1276- assert all (x .ndim == 3 for x in (a , b , a_scales , b_scales ))
1276+ if not all (x .ndim == 3 for x in (a , b , a_scales , b_scales )):
1277+ raise ValueError (
1278+ "scaled_matmul requires all inputs to be 3-dimensional arrays"
1279+ )
1280+
12771281 B_a , M_a , K_a = a .shape
12781282 B_b , N_b , K_b = b .shape
1279- assert K_a == K_b and B_a == B_b
1283+ if K_a != K_b or B_a != B_b :
1284+ raise ValueError (
1285+ "scaled_matmul requires inputs a and b to have matching batch (B) "
1286+ f"and contract (K) dimensions, but got shapes { a .shape } and "
1287+ f"{ b .shape } "
1288+ )
1289+
12801290 B_as , M_as , K_as = a_scales .shape
12811291 B_bs , N_bs , K_bs = b_scales .shape
1282- assert K_as == K_bs and B_as == B_bs
1283- assert M_as == M_a and N_bs == N_b
1292+ if K_as != K_bs or B_as != B_bs :
1293+ raise ValueError (
1294+ "scaled_matmul requires scales to have matching batch (B) and "
1295+ f"contract (K) dimensions, but got shapes { a_scales .shape } and "
1296+ f"{ b_scales .shape } "
1297+ )
1298+
1299+ if M_as != M_a or N_bs != N_b :
1300+ raise ValueError (
1301+ "scaled_matmul requires scales to match non-contract dimensions of "
1302+ f"inputs, but got shapes a: { a .shape } , b: { b .shape } , a_scales: "
1303+ f"{ a_scales .shape } , b_scales: { b_scales .shape } "
1304+ )
12841305
12851306 preferred_element_type = dtypes .canonicalize_dtype (
12861307 np .dtype (preferred_element_type )
0 commit comments