Skip to content

Commit 8b5262c

Browse files
committed
Removed debug prints
1 parent 84f3765 commit 8b5262c

File tree

1 file changed

+0
-7
lines changed

1 file changed

+0
-7
lines changed

mpx/experimental/fp8.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,10 @@ def quantized_multiplication_fwd(a: ArrayLike, b: ArrayLike, dtype8, dimension_n
4343

4444
# f_bwd :: (c, CT b) -> CT a
4545
def quantized_multiplication_bwd(dtype8, dimension_numbers, precision, preferred_element_type, out_sharding, c, dy_dc):
46-
print("--------")
47-
print(c)
48-
print(dy_dc)
4946
a_q, b_q, scaling_a, scaling_b = c
5047
# backward is performed in fp32 TODO allow to change it.
5148
a = a_q.astype(jnp.float32) / scaling_a
5249
b = b_q.astype(jnp.float32) / scaling_b
53-
print(dy_dc.shape)
54-
print(b.shape)
5550
dy_da = jax.lax.dot_general_p.bind(dy_dc, b.T, dimension_numbers=dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, out_sharding=out_sharding)
5651
dy_db = jax.lax.dot_general_p.bind(a.T, dy_dc, dimension_numbers=dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, out_sharding=out_sharding)
5752

@@ -61,8 +56,6 @@ def quantized_multiplication_bwd(dtype8, dimension_numbers, precision, preferred
6156

6257
@quax.register(jax.lax.dot_general_p)
6358
def _(lhs: ArrayLike, rhs: ArrayLike, **params):
64-
print("Performing a matmul!")
65-
print(params)
6659
return quantized_multiplication(lhs, rhs, jnp.float8_e4m3, **params)
6760

6861
if __name__ == "__main__":

0 commit comments

Comments
 (0)