Skip to content

Commit 5ddec65

Browse files
committed
Remove asserts
1 parent f949b8b commit 5ddec65

File tree

1 file changed

+25
-4
lines changed

1 file changed

+25
-4
lines changed

jax/_src/nn/functions.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1273,14 +1273,35 @@ def scaled_matmul(
12731273
>>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu)
12741274
>>> scaled_matmul(a, b, a_scales, b_scales)
12751275
"""
1276-
assert all(x.ndim == 3 for x in (a, b, a_scales, b_scales))
1276+
if not all(x.ndim == 3 for x in (a, b, a_scales, b_scales)):
1277+
raise ValueError(
1278+
"scaled_matmul requires all inputs to be 3-dimensional arrays"
1279+
)
1280+
12771281
B_a, M_a, K_a = a.shape
12781282
B_b, N_b, K_b = b.shape
1279-
assert K_a == K_b and B_a == B_b
1283+
if K_a != K_b or B_a != B_b:
1284+
raise ValueError(
1285+
"scaled_matmul requires inputs a and b to have matching batch (B) "
1286+
f"and contract (K) dimensions, but got shapes {a.shape} and "
1287+
f"{b.shape}"
1288+
)
1289+
12801290
B_as, M_as, K_as = a_scales.shape
12811291
B_bs, N_bs, K_bs = b_scales.shape
1282-
assert K_as == K_bs and B_as == B_bs
1283-
assert M_as == M_a and N_bs == N_b
1292+
if K_as != K_bs or B_as != B_bs:
1293+
raise ValueError(
1294+
"scaled_matmul requires scales to have matching batch (B) and "
1295+
f"contract (K) dimensions, but got shapes {a_scales.shape} and "
1296+
f"{b_scales.shape}"
1297+
)
1298+
1299+
if M_as != M_a or N_bs != N_b:
1300+
raise ValueError(
1301+
"scaled_matmul requires scales to match non-contract dimensions of "
1302+
f"inputs, but got shapes a: {a.shape}, b: {b.shape}, a_scales: "
1303+
f"{a_scales.shape}, b_scales: {b_scales.shape}"
1304+
)
12841305

12851306
preferred_element_type = dtypes.canonicalize_dtype(
12861307
np.dtype(preferred_element_type)

0 commit comments

Comments
 (0)