1313from jaxtyping import Array , Float , Int , PyTree , PRNGKeyArray , ArrayLike
1414
1515from .._dtypes import forward_datatype , backward_datatype
16- from .._cast import cast_tree
16+ from .._cast import cast_function
1717
1818
1919def max_val (dtype ):
2020 return (jnp .finfo (dtype ).max ).astype (jnp .float32 )
2121
22- @partial (jax .custom_vjp , nondiff_argnames = ("dtype8" , 'dimension_numbers' , 'precision' , 'preferred_element_type' , 'out_sharding' ))
23- def quantized_multiplication (a : ArrayLike , b : ArrayLike , dtype8 , dimension_numbers , precision , preferred_element_type , out_sharding ):
22+ @partial (jax .custom_vjp , nondiff_argnames = ('dimension_numbers' , 'precision' , 'preferred_element_type' , 'out_sharding' ))
23+ def quantized_multiplication (a : ArrayLike , b : ArrayLike , dimension_numbers , precision , preferred_element_type , out_sharding ):
2424 a_max = jnp .max (jnp .abs (a ))
2525 b_max = jnp .max (jnp .abs (b ))
26- max_dtype = max_val (dtype8 )
26+ fwd_dtype = forward_datatype ()
27+ max_dtype = max_val (fwd_dtype )
2728 scaling_a = max_dtype / (a_max + 1e-8 )
2829 scaling_b = max_dtype / (b_max + 1e-8 )
2930
30- a_q = (a * scaling_a ).astype (dtype8 )
31- b_q = (b * scaling_b ).astype (dtype8 )
31+ a_q = (a * scaling_a ).astype (fwd_dtype )
32+ b_q = (b * scaling_b ).astype (fwd_dtype )
3233
3334 result_q = jax .lax .dot_general_p .bind (a_q , b_q , dimension_numbers = dimension_numbers , precision = precision , preferred_element_type = preferred_element_type , out_sharding = out_sharding )
34-
35- result = (result_q .astype (jnp .float32 )) / (scaling_a * scaling_b )
35+ result = (result_q .astype (backward_datatype ())) / (scaling_a * scaling_b )
3636 return result
3737
3838
39- def quantized_multiplication_fwd (a : ArrayLike , b : ArrayLike , dtype8 , dimension_numbers , precision , preferred_element_type , out_sharding ):
39+ def quantized_multiplication_fwd (a : ArrayLike , b : ArrayLike , dimension_numbers , precision , preferred_element_type , out_sharding ):
4040 a_max = jnp .max (jnp .abs (a ))
4141 b_max = jnp .max (jnp .abs (b ))
42- max_dtype = max_val (dtype8 )
42+ fwd_dtype = forward_datatype ()
43+ max_dtype = max_val (fwd_dtype )
4344 scaling_a = max_dtype / (a_max + 1e-8 )
4445 scaling_b = max_dtype / (b_max + 1e-8 )
4546
46- a_q = (a * scaling_a ).astype (dtype8 )
47- b_q = (b * scaling_b ).astype (dtype8 )
47+ a_q = (a * scaling_a ).astype (fwd_dtype )
48+ b_q = (b * scaling_b ).astype (fwd_dtype )
4849 # we want to save the quantized versions for the backward pass to save memory
49- return quantized_multiplication (a , b , dtype8 , dimension_numbers , precision , preferred_element_type , out_sharding ), (a_q , b_q , scaling_a , scaling_b )
50+ return quantized_multiplication (a , b , dimension_numbers , precision , preferred_element_type , out_sharding ), (a_q , b_q , scaling_a , scaling_b )
5051
5152# f_bwd :: (c, CT b) -> CT a
52- def quantized_multiplication_bwd (dtype8 , dimension_numbers , precision , preferred_element_type , out_sharding , c , dy_dc ):
53+ def quantized_multiplication_bwd (dimension_numbers , precision , preferred_element_type , out_sharding , c , dy_dc ):
5354 a_q , b_q , scaling_a , scaling_b = c
5455 backward_dtype = backward_datatype ()
5556 # backward is performed in fp32 TODO allow to change it.
@@ -62,41 +63,24 @@ def quantized_multiplication_bwd(dtype8, dimension_numbers, precision, preferred
6263
6364quantized_multiplication .defvjp (quantized_multiplication_fwd , quantized_multiplication_bwd )
6465
66+
6567@quax .register (jax .lax .dot_general_p )
6668def _ (lhs : ArrayLike , rhs : ArrayLike , ** params ):
67- return quantized_multiplication (lhs , rhs , jnp . float8_e4m3 , ** params )
69+ return quantized_multiplication (lhs , rhs , ** params )
6870
6971
70-
71- def cast_function (func , dtype , return_dtype = None ):
72+ def cast_function_fwd_bwd (f : callable ) -> callable :
7273 """
73- Casts the function to the specified data type.
74+ Casts a function to use the specified forward and backward data types.
75+ Args:
76+ f (callable): The function to be cast.
77+ Returns:
78+ callable: A new function that uses the specified data types for forward and backward passes.
7479 """
7580
76- if return_dtype is None :
77- return_dtype = dtype
78-
79- def wrapper (* args , ** kwargs ):
80- args_cast = []
81- for arg in args :
82- args_cast .append (cast_tree (arg , dtype ))
83- args_cast = tuple (args_cast )
84-
85- kwargs_cast = {}
86- for key , value in kwargs .items ():
87- kwargs_cast [key ] = cast_tree (value , dtype )
88-
89- results = func (* args_cast , ** kwargs_cast )
90-
91- if type (results ) == tuple :
92- results_converted = []
93- for r in results :
94- results_converted .append (cast_tree (r , return_dtype ))
95- return tuple (results_converted )
96- elif eqx .is_array (results ):
97- return cast_tree (results , return_dtype )
98- return results
99-
100- return wrapper
81+ # cast inuts to bwd_dtype. This makes all non multiply operations to be in bwd_dtype
82+ f = cast_function (f , backward_datatype ())
10183
84+ f = quax .quaxify (f )
10285
86+ return f
0 commit comments