 
 import jax
 import jax.numpy as jnp
+import jmp
 import optax
 import tensorflow_datasets as tfds
 from flax import linen as nn
 
 
 class CifarWorkload(BaseCifarWorkload):
+  def __init__(self, *args, **kwargs) -> None:
+    super().__init__(*args, **kwargs)
+    compute_dtype = spec.JAX_DTYPE_MAP[self._compute_dtype]
+    param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype]
+    output_dtype = compute_dtype
+    self._mp_policy = jmp.Policy(
+      compute_dtype=compute_dtype,
+      param_dtype=param_dtype,
+      output_dtype=output_dtype,
+    )
+
   def _build_cifar_dataset(
     self,
     data_rng: spec.RandomState,
@@ -80,7 +92,8 @@ def sync_batch_stats(
   def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     """Dropout is unused."""
     model_cls = getattr(models, 'ResNet18')
-    model = model_cls(num_classes=self._num_classes, dtype=jnp.float32)
+    param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype]
+    model = model_cls(num_classes=self._num_classes, dtype=param_dtype)
     self._model = model
     input_shape = (1, 32, 32, 3)
     variables = jax.jit(model.init)(
@@ -89,7 +102,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     model_state, params = pop(variables, 'params')
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    model_state = jax_sharding_utils.replicate(params)
+    model_state = jax_sharding_utils.replicate(model_state)
     params = jax_sharding_utils.replicate(params)
     return params, model_state
 
@@ -110,24 +123,32 @@ def model_fn(
     del mode
     del rng
     del dropout_rate
+    # Cast params and inputs to compute dtype
+    params, inputs = self._mp_policy.cast_to_compute(
+      (params, augmented_and_preprocessed_input_batch['inputs'])
+    )
     variables = {'params': params, **model_state}
     if update_batch_norm:
       logits, new_model_state = self._model.apply(
         variables,
-        augmented_and_preprocessed_input_batch['inputs'],
+        inputs,
         update_batch_norm=update_batch_norm,
         mutable=['batch_stats'],
         use_running_average_bn=use_running_average_bn,
       )
+      # Cast logits to output dtype
+      logits = self._mp_policy.cast_to_output(logits)
       return logits, new_model_state
     else:
       logits = self._model.apply(
         variables,
-        augmented_and_preprocessed_input_batch['inputs'],
+        inputs,
         update_batch_norm=update_batch_norm,
         mutable=False,
         use_running_average_bn=use_running_average_bn,
      )
+      # Cast logits to output dtype
+      logits = self._mp_policy.cast_to_output(logits)
      return logits, model_state
 
   # Does NOT apply regularization, which is left to the submitter to do in
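
For reference, the jmp.Policy built in __init__ keeps parameters in the param dtype, casts the floating-point leaves of a pytree to the compute dtype on the way into the model, and casts outputs to the output dtype on the way out. A minimal standalone sketch of that behavior, with illustrative dtypes only (the workload's actual dtypes come from spec.JAX_DTYPE_MAP):

import jax.numpy as jnp
import jmp

# Illustrative policy: f32 params, bf16 compute/output. Not the workload's
# actual configuration, just a demonstration of the casting helpers.
policy = jmp.Policy(
  param_dtype=jnp.float32,
  compute_dtype=jnp.bfloat16,
  output_dtype=jnp.bfloat16,
)

params = {'w': jnp.ones((4, 4), jnp.float32)}
inputs = jnp.ones((2, 4), jnp.float32)

# cast_to_compute casts every floating-point leaf of the pytree, mirroring
# the (params, inputs) cast in model_fn above.
params_c, inputs_c = policy.cast_to_compute((params, inputs))
assert params_c['w'].dtype == jnp.bfloat16
assert inputs_c.dtype == jnp.bfloat16

logits = inputs_c @ params_c['w']
logits = policy.cast_to_output(logits)  # no-op here, output dtype is bf16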