-
Notifications
You must be signed in to change notification settings - Fork 78
Expand file tree
/
Copy pathworkload.py
More file actions
223 lines (206 loc) · 7.47 KB
/
workload.py
File metadata and controls
223 lines (206 loc) · 7.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
"""CIFAR workload implemented in Jax."""
import functools
from typing import Any, Dict, Iterator, Optional, Tuple
import jax
import jax.numpy as jnp
import optax
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.core import pop
from jax import lax
from algoperf import jax_sharding_utils, param_utils, spec
from algoperf.workloads.cifar.cifar_jax import models
from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter
from algoperf.workloads.cifar.workload import BaseCifarWorkload
class CifarWorkload(BaseCifarWorkload):
  """CIFAR-10 workload implemented in Jax/Flax (ResNet-18 backbone)."""

  def _build_cifar_dataset(
    self,
    data_rng: spec.RandomState,
    split: str,
    data_dir: str,
    batch_size: int,
    cache: Optional[bool] = None,
    repeat_final_dataset: Optional[bool] = None,
  ) -> Iterator[Dict[str, spec.Tensor]]:
    """Build a batched input iterator for a CIFAR-10 split.

    Args:
      data_rng: RNG used by the input pipeline for shuffling/augmentation.
      split: One of 'train', 'eval_train', 'validation' (mapped onto slices
        of the TFDS 'train' split), or any other TFDS split string, which is
        passed through unchanged.
      data_dir: Directory where TFDS downloads/stores the dataset.
      batch_size: Global batch size.
      cache: Whether to cache the dataset in memory. Defaults to caching
        every split except the (randomly augmented) training split.
      repeat_final_dataset: Whether to repeat the dataset after the final
        batch, e.g. for repeated evaluation passes.

    Returns:
      An iterator yielding dicts of batched tensors.
    """
    ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir)
    ds_builder.download_and_prepare()
    train = split == 'train'
    # Both the train and validation sets are carved out of TFDS's 50k-example
    # 'train' split, so the two sizes must exactly partition it.
    assert self.num_train_examples + self.num_validation_examples == 50000
    if split in ['train', 'eval_train']:
      split = f'train[:{self.num_train_examples}]'
    elif split == 'validation':
      split = f'train[{self.num_train_examples}:]'
    ds = create_input_iter(
      split,
      ds_builder,
      data_rng,
      batch_size,
      self.train_mean,
      self.train_stddev,
      self.crop_size,
      self.padding_size,
      train=train,
      # Only skip caching for the training split (its augmentation is
      # randomized per epoch), unless the caller overrides.
      cache=not train if cache is None else cache,
      repeat_final_dataset=repeat_final_dataset,
    )
    return ds

  def _build_input_queue(
    self,
    data_rng: spec.RandomState,
    split: str,
    data_dir: str,
    global_batch_size: int,
    cache: Optional[bool] = None,
    repeat_final_dataset: Optional[bool] = None,
    num_batches: Optional[int] = None,
  ) -> Iterator[Dict[str, spec.Tensor]]:
    """Return the input iterator for `split`.

    `num_batches` is accepted for interface compatibility with the base
    class but is unused here — the CIFAR pipeline yields the full split.
    """
    del num_batches
    return self._build_cifar_dataset(
      data_rng, split, data_dir, global_batch_size, cache, repeat_final_dataset
    )

  def sync_batch_stats(
    self, model_state: spec.ModelAuxiliaryState
  ) -> spec.ModelAuxiliaryState:
    """Sync the batch statistics across replicas.

    Returns a copy of `model_state` whose 'batch_stats' have been averaged
    across devices; the input is not mutated.
    """
    # An axis_name is passed to pmap which can then be used by pmean.
    # In this case each device has its own version of the batch statistics
    # and we average them.
    # NOTE(review): this still uses pmap while the rest of this file uses
    # jit + explicit shardings — confirm callers pass per-device-stacked
    # batch_stats as pmap expects.
    avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
    new_model_state = model_state.copy()
    new_model_state['batch_stats'] = avg_fn(model_state['batch_stats'])
    return new_model_state

  def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
    """Initialize and replicate ResNet-18 params and batch-norm state.

    Dropout is unused. Returns `(params, model_state)` where `model_state`
    holds the non-parameter collections (batch_stats), both replicated
    across devices.
    """
    model_cls = getattr(models, 'ResNet18')
    model = model_cls(num_classes=self._num_classes, dtype=jnp.float32)
    self._model = model
    input_shape = (1, 32, 32, 3)  # NHWC: single 32x32 RGB image.
    variables = jax.jit(model.init)(
      {'params': rng}, jnp.ones(input_shape, model.dtype)
    )
    # Split the variable dict into trainable params and auxiliary state
    # (batch_stats).
    model_state, params = pop(variables, 'params')
    self._param_shapes = param_utils.jax_param_shapes(params)
    self._param_types = param_utils.jax_param_types(self._param_shapes)
    # Bug fix: previously replicated `params` into `model_state`, which
    # discarded the batch-norm statistics popped above and returned the
    # parameter tree in their place.
    model_state = jax_sharding_utils.replicate(model_state)
    params = jax_sharding_utils.replicate(params)
    return params, model_state

  def is_output_params(self, param_key: spec.ParameterKey) -> bool:
    """Return True iff `param_key` names the final (output) dense layer."""
    return param_key == 'Dense_0'

  def model_fn(
    self,
    params: spec.ParameterContainer,
    augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
    model_state: spec.ModelAuxiliaryState,
    mode: spec.ForwardPassMode,
    rng: spec.RandomState,
    update_batch_norm: bool,
    use_running_average_bn: Optional[bool] = None,
    dropout_rate: float = 0.0,
  ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
    """Run a forward pass and return `(logits, model_state)`.

    When `update_batch_norm` is True the batch_stats collection is marked
    mutable and the updated state is returned; otherwise the input
    `model_state` is passed through unchanged. `mode`, `rng`, and
    `dropout_rate` are unused by this model.
    """
    del mode
    del rng
    del dropout_rate
    variables = {'params': params, **model_state}
    if update_batch_norm:
      logits, new_model_state = self._model.apply(
        variables,
        augmented_and_preprocessed_input_batch['inputs'],
        update_batch_norm=update_batch_norm,
        mutable=['batch_stats'],
        use_running_average_bn=use_running_average_bn,
      )
      return logits, new_model_state
    else:
      logits = self._model.apply(
        variables,
        augmented_and_preprocessed_input_batch['inputs'],
        update_batch_norm=update_batch_norm,
        mutable=False,
        use_running_average_bn=use_running_average_bn,
      )
      return logits, model_state

  # Does NOT apply regularization, which is left to the submitter to do in
  # `update_params`.
  def loss_fn(
    self,
    label_batch: spec.Tensor,  # Dense or one-hot labels.
    logits_batch: spec.Tensor,
    mask_batch: Optional[spec.Tensor] = None,
    label_smoothing: float = 0.0,
  ) -> Dict[str, spec.Tensor]:  # differentiable
    """Evaluate the (masked) loss function at (label_batch, logits_batch).

    Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
    valid examples in batch, 'per_example': 1-d array of per-example losses}
    (not synced across devices).
    """
    one_hot_targets = jax.nn.one_hot(label_batch, self._num_classes)
    smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing)
    # Cross-entropy against the (smoothed) one-hot targets.
    per_example_losses = -jnp.sum(
      smoothed_targets * nn.log_softmax(logits_batch), axis=-1
    )
    # `mask_batch` is assumed to be shape [batch].
    if mask_batch is not None:
      per_example_losses *= mask_batch
      n_valid_examples = mask_batch.sum()
    else:
      n_valid_examples = len(per_example_losses)
    summed_loss = per_example_losses.sum()
    return {
      'summed': summed_loss,
      'n_valid_examples': n_valid_examples,
      'per_example': per_example_losses,
    }

  def _compute_metrics(
    self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor
  ) -> Dict[str, spec.Tensor]:
    """Return summed (not averaged) loss and correct-prediction count."""
    summed_loss = self.loss_fn(labels, logits, weights)['summed']
    # Number of correct predictions, weighted by the per-example weights.
    accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights)
    metrics = {
      'loss': summed_loss,
      'accuracy': accuracy,
    }
    return metrics

  def _eval_model(
    self,
    params: spec.ParameterContainer,
    batch: Dict[str, spec.Tensor],
    model_state: spec.ModelAuxiliaryState,
    rng: spec.RandomState,
  ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
    """Return summed loss/accuracy for one batch as a dict of Python scalars."""

    @functools.partial(
      jax.jit,
      in_shardings=(
        jax_sharding_utils.get_replicate_sharding(),  # params
        jax_sharding_utils.get_batch_dim_sharding(),  # batch
        jax_sharding_utils.get_replicate_sharding(),  # model_state
        jax_sharding_utils.get_batch_dim_sharding(),  # rng
      ),
    )
    def _eval_model_jitted(
      params: spec.ParameterContainer,
      batch: Dict[str, spec.Tensor],
      model_state: spec.ModelAuxiliaryState,
      rng: spec.RandomState,
    ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
      """Jitted forward pass + metrics with batch-norm in inference mode."""
      logits, _ = self.model_fn(
        params,
        batch,
        model_state,
        spec.ForwardPassMode.EVAL,
        rng,
        update_batch_norm=False,
      )
      weights = batch.get('weights')
      if weights is None:
        # No mask supplied: every example in the batch counts.
        weights = jnp.ones(len(logits))
      return self._compute_metrics(logits, batch['targets'], weights)

    metrics = _eval_model_jitted(params, batch, model_state, rng)
    # Pull device arrays back to host Python scalars.
    return jax.tree.map(lambda x: x.item(), metrics)

  def _normalize_eval_metrics(
    self, num_examples: int, total_metrics: Dict[str, Any]
  ) -> Dict[str, float]:
    """Normalize summed eval metrics to per-example averages."""
    return jax.tree.map(lambda x: x / num_examples, total_metrics)