
Commit 49c5388

committed
merge
2 parents 4c37f2a + 6f7d638 commit 49c5388

File tree

19 files changed: +286 -88 lines changed

algoperf/spec.py

Lines changed: 23 additions & 0 deletions
@@ -6,11 +6,34 @@
 from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
 
 import jax
+import jax.numpy as jnp
+import torch
 import torch.nn.functional as F
 from absl import logging
 from torch import nn
 
 
+class DTYPE(enum.Enum):
+  FLOAT32 = 0
+  FLOAT16 = 1
+  BFLOAT16 = 2
+
+
+# Mapping from DTYPE enum to JAX dtypes
+JAX_DTYPE_MAP = {
+  DTYPE.FLOAT32: jnp.float32,
+  DTYPE.FLOAT16: jnp.float16,
+  DTYPE.BFLOAT16: jnp.bfloat16,
+}
+
+# Mapping from DTYPE enum to PyTorch dtypes
+PYTORCH_DTYPE_MAP = {
+  DTYPE.FLOAT32: torch.float32,
+  DTYPE.FLOAT16: torch.float16,
+  DTYPE.BFLOAT16: torch.bfloat16,
+}
+
+
 class LossType(enum.Enum):
   SOFTMAX_CROSS_ENTROPY = 0
   SIGMOID_CROSS_ENTROPY = 1
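
Note: the `DTYPE` enum plus the two mapping dicts give workloads a framework-agnostic way to declare precision and resolve it to a concrete dtype at initialization time. A minimal usage sketch (standalone, not part of this commit):

import jax.numpy as jnp
import torch

from algoperf import spec

# Resolve the framework-agnostic enum into concrete framework dtypes.
jax_compute = spec.JAX_DTYPE_MAP[spec.DTYPE.BFLOAT16]     # jnp.bfloat16
jax_params = spec.JAX_DTYPE_MAP[spec.DTYPE.FLOAT32]       # jnp.float32
pt_compute = spec.PYTORCH_DTYPE_MAP[spec.DTYPE.BFLOAT16]  # torch.bfloat16

assert jax_compute == jnp.bfloat16
assert pt_compute == torch.bfloat16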

algoperf/workloads/cifar/cifar_jax/input_pipeline.py

Lines changed: 0 additions & 2 deletions
@@ -11,7 +11,6 @@
 import jax
 import tensorflow as tf
 import tensorflow_datasets as tfds
-from flax import jax_utils
 
 from algoperf import spec
 from algoperf.data_utils import shard_and_maybe_pad_np
@@ -186,5 +185,4 @@ def create_input_iter(
     ),
     ds,
   )
-  it = jax_utils.prefetch_to_device(it, 2)
   return it

algoperf/workloads/cifar/cifar_jax/models.py

Lines changed: 5 additions & 3 deletions
@@ -31,7 +31,7 @@ def __call__(
     update_batch_norm: bool = True,
     use_running_average_bn: bool = None,
   ) -> spec.Tensor:
-    conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
+    conv = functools.partial(nn.Conv, use_bias=False, param_dtype=self.dtype)
 
     # Preserve default behavior for backwards compatibility
     if use_running_average_bn is None:
@@ -41,7 +41,7 @@ def __call__(
       use_running_average=use_running_average_bn,
       momentum=0.9,
       epsilon=1e-5,
-      dtype=self.dtype,
+      param_dtype=self.dtype,
     )
 
     x = conv(
@@ -66,7 +66,9 @@ def __call__(
     x = nn.avg_pool(x, (4, 4), strides=(4, 4))
     x = jnp.mean(x, axis=(1, 2))
     x = nn.Dense(
-      self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype
+      self.num_classes,
+      kernel_init=nn.initializers.normal(),
+      param_dtype=self.dtype,
     )(x)
     return x
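
Note: swapping `dtype` for `param_dtype` matters because, in Flax linen, `param_dtype` fixes the dtype parameters are initialized and stored in, while the computation dtype is inferred from the operands when `dtype` is left unset. A standalone sketch of that behavior (not repo code; the manual cast below mirrors what `jmp.Policy.cast_to_compute` does in the workloads):

import jax
import jax.numpy as jnp
from flax import linen as nn

# param_dtype controls how the kernel/bias are created and stored.
dense = nn.Dense(features=8, param_dtype=jnp.float32)
variables = dense.init(jax.random.PRNGKey(0), jnp.ones((1, 4), jnp.bfloat16))
print(variables['params']['kernel'].dtype)  # float32

# Casting the params before `apply` is what makes the forward pass run in bf16.
bf16_vars = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), variables)
y = dense.apply(bf16_vars, jnp.ones((1, 4), jnp.bfloat16))
print(y.dtype)  # bfloat16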

algoperf/workloads/cifar/cifar_jax/workload.py

Lines changed: 25 additions & 4 deletions
@@ -5,6 +5,7 @@
 
 import jax
 import jax.numpy as jnp
+import jmp
 import optax
 import tensorflow_datasets as tfds
 from flax import linen as nn
@@ -18,6 +19,17 @@
 
 
 class CifarWorkload(BaseCifarWorkload):
+  def __init__(self, *args, **kwargs) -> None:
+    super().__init__(*args, **kwargs)
+    compute_dtype = spec.JAX_DTYPE_MAP[self._compute_dtype]
+    param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype]
+    output_dtype = compute_dtype
+    self._mp_policy = jmp.Policy(
+      compute_dtype=compute_dtype,
+      param_dtype=param_dtype,
+      output_dtype=output_dtype,
+    )
+
   def _build_cifar_dataset(
     self,
     data_rng: spec.RandomState,
@@ -80,7 +92,8 @@ def sync_batch_stats(
   def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     """Dropout is unused."""
     model_cls = getattr(models, 'ResNet18')
-    model = model_cls(num_classes=self._num_classes, dtype=jnp.float32)
+    param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype]
+    model = model_cls(num_classes=self._num_classes, dtype=param_dtype)
     self._model = model
     input_shape = (1, 32, 32, 3)
     variables = jax.jit(model.init)(
@@ -89,7 +102,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     model_state, params = pop(variables, 'params')
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    model_state = jax_sharding_utils.replicate(params)
+    model_state = jax_sharding_utils.replicate(model_state)
     params = jax_sharding_utils.replicate(params)
     return params, model_state
 
@@ -110,24 +123,32 @@ def model_fn(
     del mode
     del rng
     del dropout_rate
+    # Cast params and inputs to compute dtype
+    params, inputs = self._mp_policy.cast_to_compute(
+      (params, augmented_and_preprocessed_input_batch['inputs'])
+    )
     variables = {'params': params, **model_state}
     if update_batch_norm:
       logits, new_model_state = self._model.apply(
         variables,
-        augmented_and_preprocessed_input_batch['inputs'],
+        inputs,
         update_batch_norm=update_batch_norm,
         mutable=['batch_stats'],
         use_running_average_bn=use_running_average_bn,
       )
+      # Cast logits to output dtype
+      logits = self._mp_policy.cast_to_output(logits)
       return logits, new_model_state
     else:
       logits = self._model.apply(
         variables,
-        augmented_and_preprocessed_input_batch['inputs'],
+        inputs,
         update_batch_norm=update_batch_norm,
         mutable=False,
         use_running_average_bn=use_running_average_bn,
      )
+      # Cast logits to output dtype
+      logits = self._mp_policy.cast_to_output(logits)
      return logits, model_state
 
   # Does NOT apply regularization, which is left to the submitter to do in
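
Note: the casting pattern in `model_fn` follows the usual jmp recipe: keep master params in `param_dtype`, cast params and inputs to `compute_dtype` for the forward pass, then cast the result to `output_dtype`. A standalone sketch with illustrative names (`policy`, `params`, `inputs` are not repo identifiers):

import jax.numpy as jnp
import jmp

policy = jmp.Policy(
  compute_dtype=jnp.bfloat16,
  param_dtype=jnp.float32,
  output_dtype=jnp.bfloat16,
)

params = {'w': jnp.ones((4, 4), jnp.float32)}
inputs = jnp.ones((2, 4), jnp.float32)

# Cast the whole tree to the compute dtype for the forward pass ...
params_c, inputs_c = policy.cast_to_compute((params, inputs))
logits = inputs_c @ params_c['w']
# ... then cast the result to the declared output dtype.
logits = policy.cast_to_output(logits)
print(params_c['w'].dtype, logits.dtype)  # bfloat16 bfloat16

Since the workloads set output_dtype equal to compute_dtype, `cast_to_output` is effectively a no-op here; it would only change the result if outputs were kept in float32.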

algoperf/workloads/cifar/cifar_pytorch/models.py

Lines changed: 21 additions & 3 deletions
@@ -29,11 +29,13 @@ def __init__(
     width_per_group: int = 64,
     replace_stride_with_dilation: Optional[List[bool]] = None,
     norm_layer: Optional[Callable[..., nn.Module]] = None,
+    dtype: torch.dtype = torch.float32,
   ) -> None:
     super().__init__()
     if norm_layer is None:
       norm_layer = nn.BatchNorm2d
     self._norm_layer = norm_layer
+    self.dtype = dtype
 
     self.inplanes = 64
     self.dilation = 1
@@ -49,7 +51,13 @@ def __init__(
     self.groups = groups
     self.base_width = width_per_group
     self.conv1 = nn.Conv2d(
-      3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
+      3,
+      self.inplanes,
+      kernel_size=3,
+      stride=1,
+      padding=1,
+      bias=False,
+      dtype=dtype,
     )
     self.bn1 = norm_layer(self.inplanes)
     self.relu = nn.ReLU(inplace=True)
@@ -63,7 +71,7 @@ def __init__(
     self.layer4 = self._make_layer(
       block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
     )
-    self.fc = nn.Linear(512 * block.expansion, num_classes)
+    self.fc = nn.Linear(512 * block.expansion, num_classes, dtype=dtype)
     self.reset_parameters()
 
   def reset_parameters(self) -> None:
@@ -105,7 +113,15 @@ def _make_layer(
       downsample = torch.nn.Sequential(
         collections.OrderedDict(
           [
-            ('conv', conv1x1(self.inplanes, planes * block.expansion, stride)),
+            (
+              'conv',
+              conv1x1(
+                self.inplanes,
+                planes * block.expansion,
+                stride,
+                dtype=self.dtype,
+              ),
+            ),
             ('bn', norm_layer(planes * block.expansion)),
           ]
         )
@@ -122,6 +138,7 @@ def _make_layer(
         self.base_width,
         previous_dilation,
         norm_layer,
+        dtype=self.dtype,
       )
     )
     self.inplanes = planes * block.expansion
@@ -134,6 +151,7 @@ def _make_layer(
          base_width=self.base_width,
          dilation=self.dilation,
          norm_layer=norm_layer,
+          dtype=self.dtype,
        )
      )
 
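
Note: the PyTorch side threads a `dtype` factory argument through the convolution and linear layers so their parameters are created directly in the requested precision, while the batch-norm layers are constructed without a dtype and therefore stay in the default float32. A standalone sketch of that mechanism (bfloat16 used for illustration only; the CIFAR workload keeps params in float32):

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False, dtype=torch.bfloat16)
fc = nn.Linear(512, 10, dtype=torch.bfloat16)
bn = nn.BatchNorm2d(64)  # no dtype argument, parameters stay float32

print(conv.weight.dtype)  # torch.bfloat16
print(fc.weight.dtype)    # torch.bfloat16
print(bn.weight.dtype)    # torch.float32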

algoperf/workloads/cifar/cifar_pytorch/workload.py

Lines changed: 7 additions & 2 deletions
@@ -25,6 +25,8 @@ def __init__(self, *args, **kwargs) -> None:
     # Is set in submission_runner.py for workloads with PyTorch evaluation
     # data loaders via the `eval_num_workers` property.
     self._eval_num_workers = None
+    self._param_dtype_pt = spec.PYTORCH_DTYPE_MAP[self._param_dtype]
+    self._compute_dtype_pt = spec.PYTORCH_DTYPE_MAP[self._compute_dtype]
 
   @property
   def eval_num_workers(self) -> int:
@@ -128,7 +130,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
       return self._model, None
 
     torch.random.manual_seed(rng[0])
-    self._model = resnet18(num_classes=self._num_classes)
+    self._model = resnet18(
+      num_classes=self._num_classes, dtype=self._param_dtype_pt
+    )
     self._param_shapes = param_utils.pytorch_param_shapes(self._model)
     self._param_types = param_utils.pytorch_param_types(self._param_shapes)
     self._model.to(DEVICE)
@@ -175,7 +179,8 @@ def model_fn(
       spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
     }
     with contexts[mode]():
-      logits_batch = model(augmented_and_preprocessed_input_batch['inputs'])
+      with torch.autocast(device_type='cuda', dtype=self._compute_dtype_pt):
+        logits_batch = model(augmented_and_preprocessed_input_batch['inputs'])
     return logits_batch, None
 
   # Does NOT apply regularization, which is left to the submitter to do in
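
Note: on the PyTorch side the parameters stay in `_param_dtype_pt` and the forward pass is wrapped in `torch.autocast`, which runs matmul-heavy ops in `_compute_dtype_pt` without touching the stored weights. A standalone sketch (assumes a CUDA device is available):

import torch

model = torch.nn.Linear(8, 2).cuda()  # parameters remain float32
x = torch.randn(4, 8, device='cuda')

with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
  logits = model(x)  # the matmul runs in bfloat16

print(next(model.parameters()).dtype)  # torch.float32
print(logits.dtype)                    # torch.bfloat16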

algoperf/workloads/cifar/workload.py

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,8 @@
 
 class BaseCifarWorkload(spec.Workload):
   _num_classes: int = 10
+  _compute_dtype: spec.DTYPE = spec.DTYPE.BFLOAT16
+  _param_dtype: spec.DTYPE = spec.DTYPE.FLOAT32
 
   @property
   def target_metric_name(self) -> str:
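
Note: these two class attributes are the single switch the framework-specific workloads read in their `__init__` (via `JAX_DTYPE_MAP` / `PYTORCH_DTYPE_MAP`). A hypothetical override, for illustration only (`Float32CifarWorkload` is not part of this commit):

from algoperf import spec
from algoperf.workloads.cifar.workload import BaseCifarWorkload

class Float32CifarWorkload(BaseCifarWorkload):
  # Pin both compute and parameter precision back to float32.
  _compute_dtype: spec.DTYPE = spec.DTYPE.FLOAT32
  _param_dtype: spec.DTYPE = spec.DTYPE.FLOAT32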

algoperf/workloads/imagenet_resnet/imagenet_jax/models.py

Lines changed: 5 additions & 3 deletions
@@ -90,7 +90,7 @@ def __call__(
     update_batch_norm: bool = True,
     use_running_average_bn: Optional[bool] = None,
   ) -> spec.Tensor:
-    conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
+    conv = functools.partial(nn.Conv, use_bias=False, param_dtype=self.dtype)
     # Preserve default behavior for backwards compatibility
     if use_running_average_bn is None:
       use_running_average_bn = not update_batch_norm
@@ -99,7 +99,7 @@ def __call__(
       use_running_average=use_running_average_bn,
       momentum=0.9,
       epsilon=1e-5,
-      dtype=self.dtype,
+      param_dtype=self.dtype,
     )
 
     x = conv(
@@ -125,7 +125,9 @@ def __call__(
     )(x)
     x = jnp.mean(x, axis=(1, 2))
     x = nn.Dense(
-      self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype
+      self.num_classes,
+      kernel_init=nn.initializers.normal(),
+      param_dtype=self.dtype,
     )(x)
     return x

algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py

Lines changed: 22 additions & 6 deletions
@@ -11,6 +11,7 @@
 
 import jax
 import jax.numpy as jnp
+import jmp
 import optax
 import tensorflow_datasets as tfds
 from flax import linen as nn
@@ -29,6 +30,17 @@
 
 
 class ImagenetResNetWorkload(BaseImagenetResNetWorkload):
+  def __init__(self, *args, **kwargs) -> None:
+    super().__init__(*args, **kwargs)
+    compute_dtype = spec.JAX_DTYPE_MAP[self._compute_dtype]
+    param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype]
+    output_dtype = compute_dtype
+    self._mp_policy = jmp.Policy(
+      compute_dtype=compute_dtype,
+      param_dtype=param_dtype,
+      output_dtype=output_dtype,
+    )
+
   def _build_dataset(
     self,
     data_rng: spec.RandomState,
@@ -89,11 +101,12 @@ def init_model_fn(
     else:
       act_fnc = nn.relu
 
+    param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype]
     model = model_cls(
       num_classes=self._num_classes,
       act=act_fnc,
       bn_init_scale=self.bn_init_scale,
-      dtype=jnp.float32,
+      dtype=param_dtype,
     )
     self._model = model
     input_shape = (1, 224, 224, 3)
@@ -159,25 +172,28 @@ def model_fn(
     del mode
     del rng
     del dropout_rate
+    params, inputs = self._mp_policy.cast_to_compute(
+      (params, augmented_and_preprocessed_input_batch['inputs'])
+    )
     variables = {'params': params, **model_state}
     if update_batch_norm:
-      logits, new_model_state = self._model.apply(
+      logits, model_state = self._model.apply(
         variables,
-        augmented_and_preprocessed_input_batch['inputs'],
+        inputs,
         update_batch_norm=update_batch_norm,
         mutable=['batch_stats'],
         use_running_average_bn=use_running_average_bn,
       )
-      return logits, new_model_state
     else:
       logits = self._model.apply(
         variables,
-        augmented_and_preprocessed_input_batch['inputs'],
+        inputs,
         update_batch_norm=update_batch_norm,
         mutable=False,
         use_running_average_bn=use_running_average_bn,
       )
-    return logits, model_state
+    logits = self._mp_policy.cast_to_output(logits)
+    return logits, model_state
 
   # Does NOT apply regularization, which is left to the submitter to do in
   # `update_params`.

0 commit comments