@@ -43,15 +43,10 @@ def quantized_multiplication_fwd(a: ArrayLike, b: ArrayLike, dtype8, dimension_n
 
 # f_bwd :: (c, CT b) -> CT a
 def quantized_multiplication_bwd(dtype8, dimension_numbers, precision, preferred_element_type, out_sharding, c, dy_dc):
-    print("--------")
-    print(c)
-    print(dy_dc)
     a_q, b_q, scaling_a, scaling_b = c
     # backward is performed in fp32 TODO allow to change it.
     a = a_q.astype(jnp.float32) / scaling_a
     b = b_q.astype(jnp.float32) / scaling_b
-    print(dy_dc.shape)
-    print(b.shape)
     dy_da = jax.lax.dot_general_p.bind(dy_dc, b.T, dimension_numbers=dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, out_sharding=out_sharding)
     dy_db = jax.lax.dot_general_p.bind(a.T, dy_dc, dimension_numbers=dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, out_sharding=out_sharding)
 
@@ -61,8 +56,6 @@ def quantized_multiplication_bwd(dtype8, dimension_numbers, precision, preferred
 
 @quax.register(jax.lax.dot_general_p)
 def _(lhs: ArrayLike, rhs: ArrayLike, **params):
-    print("Performing a matmul!")
-    print(params)
     return quantized_multiplication(lhs, rhs, jnp.float8_e4m3, **params)
 
 if __name__ == "__main__":