
Commit a4f9948

WMT mixed-precision support
1 parent 6f7d638 commit a4f9948

5 files changed: 72 additions, 23 deletions

5 files changed

+72
-23
lines changed

algoperf/workloads/wmt/wmt_jax/models.py

Lines changed: 3 additions & 0 deletions
@@ -364,6 +364,7 @@ def __call__(
       input_embed = nn.Embed(
         num_embeddings=cfg.vocab_size,
         features=cfg.emb_dim,
+        dtype=cfg.dtype,
         embedding_init=nn.initializers.normal(stddev=1.0),
       )
     else:
@@ -437,6 +438,7 @@ def __call__(
       output_embed = nn.Embed(
         num_embeddings=cfg.vocab_size,
         features=cfg.emb_dim,
+        dtype=cfg.dtype,
         embedding_init=nn.initializers.normal(stddev=1.0),
       )
     else:
@@ -497,6 +499,7 @@ def setup(self):
       self.shared_embedding = nn.Embed(
         num_embeddings=cfg.vocab_size,
         features=cfg.emb_dim,
+        dtype=cfg.dtype,
         embedding_init=nn.initializers.normal(stddev=1.0),
       )
     else:
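For context, a minimal standalone Flax sketch (not taken from this repo) of what the `dtype` argument to `nn.Embed` does: the embedding table itself stays in `param_dtype` (float32 by default), while `dtype` sets the dtype the lookup output is promoted to.

# Minimal Flax sketch, independent of algoperf: `dtype` controls the dtype of
# the embedding lookup output; the stored table keeps its `param_dtype`.
import jax
import jax.numpy as jnp
from flax import linen as nn


class TinyEmbed(nn.Module):
  @nn.compact
  def __call__(self, ids):
    return nn.Embed(num_embeddings=8, features=4, dtype=jnp.bfloat16)(ids)


variables = TinyEmbed().init(jax.random.PRNGKey(0), jnp.array([1, 2]))
out = TinyEmbed().apply(variables, jnp.array([1, 2]))
print(jax.tree_util.tree_map(lambda x: x.dtype, variables))  # embedding table: float32
print(out.dtype)  # bfloat16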

algoperf/workloads/wmt/wmt_jax/workload.py

Lines changed: 19 additions & 0 deletions
@@ -6,6 +6,7 @@
 
 import jax
 import jax.numpy as jnp
+import jmp
 import numpy as np
 import optax
 from absl import logging
@@ -27,6 +28,17 @@ def _to_host(x: spec.Tensor) -> spec.Tensor:
 class WmtWorkload(BaseWmtWorkload):
   """WMT Jax workload."""
 
+  def __init__(self) -> None:
+    super().__init__()
+    compute_dtype = spec.JAX_DTYPE_MAP[self._compute_dtype]
+    param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype]
+    output_dtype = compute_dtype
+    self._mp_policy = jmp.Policy(
+      compute_dtype=compute_dtype,
+      param_dtype=param_dtype,
+      output_dtype=output_dtype,
+    )
+
   def compute_weighted_cross_entropy(
     self,
     logits: spec.Tensor,
@@ -251,11 +263,13 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     else:
       raise ValueError(f'Unknown activation function {self.activation}.')
 
+    param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype]
     model_config = models.TransformerConfig(
       pre_ln=self.pre_ln,
       attention_temp=self.attention_temp,
       activation=activation,
       glu=self.glu,
+      dtype=param_dtype,
     )
     self._train_model = models.Transformer(model_config)
     eval_config = replace(model_config, deterministic=True)
@@ -313,6 +327,9 @@ def model_fn(
     else:
       model = self._eval_model
 
+    # Cast params to compute dtype
+    params = self._mp_policy.cast_to_compute(params)
+
     logits_batch = model.apply(
       {'params': params},
       inputs,
@@ -324,6 +341,8 @@
       rngs={'dropout': rng},
       dropout_rate=dropout_rate,
     )
+    # Cast logits to output dtype
+    logits_batch = self._mp_policy.cast_to_output(logits_batch)
     return logits_batch, None
 
   def _build_input_queue(
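As a rough illustration of the `jmp` policy used above (a standalone sketch, not this workload's code): parameters are kept in `param_dtype`, cast to `compute_dtype` for the forward pass, and the result is cast to `output_dtype`.

# Standalone jmp sketch (not this workload): float32 params, bfloat16 compute.
import jax.numpy as jnp
import jmp

policy = jmp.Policy(
  param_dtype=jnp.float32,
  compute_dtype=jnp.bfloat16,
  output_dtype=jnp.bfloat16,
)

params = {'w': jnp.ones((4, 4))}  # stored in float32
compute_params = policy.cast_to_compute(params)  # bfloat16 copies of the leaves
logits = compute_params['w'] @ jnp.ones((4,), jnp.bfloat16)
logits = policy.cast_to_output(logits)  # no-op here since output_dtype == compute_dtype
print(compute_params['w'].dtype, logits.dtype)  # bfloat16 bfloat16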

algoperf/workloads/wmt/wmt_pytorch/models.py

Lines changed: 24 additions & 6 deletions
@@ -116,10 +116,11 @@ def __init__(
     layer_norm_eps: float = 1e-6,
     attention_temp: float = 1.0,
     pre_ln: bool = True,
+    dtype: torch.dtype = torch.float32,
   ):
     super().__init__()
-    self.pos_encoder = PositionalEncoding(d_model)
-    self.shared_embedding = nn.Embedding(ntoken, d_model)
+    self.pos_encoder = PositionalEncoding(d_model, dtype=dtype)
+    self.shared_embedding = nn.Embedding(ntoken, d_model, dtype=dtype)
     self.encoder = Encoder(
       d_model,
       nhead,
@@ -130,6 +131,7 @@ def __init__(
       layer_norm_eps,
       attention_temp,
       pre_ln,
+      dtype=dtype,
     )
     self.decoder = Decoder(
       d_model,
@@ -141,6 +143,7 @@
       layer_norm_eps,
       attention_temp,
       pre_ln,
+      dtype=dtype,
     )
     # Share positional encoding and embedding between encoder and decoder.
     self.encoder.pos_encoder = self.pos_encoder
@@ -287,6 +290,7 @@ def __init__(
     layer_norm_eps: float = 1e-6,
     attention_temp: float = 1.0,
     pre_ln: bool = True,
+    dtype: torch.dtype = torch.float32,
   ):
     super().__init__()
     self.nhead = nhead
@@ -301,8 +305,11 @@ def __init__(
       layer_norm_eps=layer_norm_eps,
       attention_temp=attention_temp,
       pre_ln=pre_ln,
+      dtype=dtype,
+    )
+    encoder_norm = (
+      nn.LayerNorm(d_model, eps=layer_norm_eps, dtype=dtype) if pre_ln else None
     )
-    encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None
     self.encoder = TransformerEncoder(encoder_layer, nlayers, encoder_norm)
 
   def forward(
@@ -332,6 +339,7 @@ def __init__(
     layer_norm_eps: float = 1e-6,
     attention_temp: float = 1.0,
     pre_ln: bool = True,
+    dtype: torch.dtype = torch.float32,
   ):
     super().__init__()
     self.nhead = nhead
@@ -347,6 +355,7 @@
       nlayers,
       attention_temp,
       pre_ln,
+      dtype=dtype,
     )
 
   def forward(
@@ -398,13 +407,18 @@ def forward(
 
 
 class PositionalEncoding(nn.Module):
-  def __init__(self, d_model: int, max_len: int = 256):
+  def __init__(
+    self,
+    d_model: int,
+    max_len: int = 256,
+    dtype: torch.dtype = torch.float32,
+  ):
     super().__init__()
 
     position = torch.arange(max_len).unsqueeze(1)
     scale_factor = -math.log(10000.0) / (d_model // 2 - 1)
     div_term = torch.exp(torch.arange(d_model // 2) * scale_factor)
-    pe = torch.zeros(1, max_len, d_model)
+    pe = torch.zeros(1, max_len, d_model, dtype=dtype)
     pe[0, :, : d_model // 2] = torch.sin(position * div_term)
     pe[0, :, d_model // 2 : 2 * (d_model // 2)] = torch.cos(position * div_term)
     self.register_buffer('pe', pe)
@@ -599,6 +613,7 @@ def __init__(
     num_layers,
     attention_temp,
     pre_ln,
+    dtype: torch.dtype = torch.float32,
   ):
     super().__init__()
     self.layers = nn.ModuleList(
@@ -612,12 +627,15 @@ def __init__(
           layer_norm_eps=layer_norm_eps,
          attention_temp=attention_temp,
          pre_ln=pre_ln,
+          dtype=dtype,
        )
        for _ in range(num_layers)
      ]
    )
    self.num_layers = num_layers
-    self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None
+    self.norm = (
+      nn.LayerNorm(d_model, eps=layer_norm_eps, dtype=dtype) if pre_ln else None
+    )
 
   def forward(
     self,
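For reference, a tiny standalone PyTorch sketch (not from this repo) of the `dtype` factory argument threaded through these modules: it sets the dtype the parameters and buffers are created in, independently of the compute dtype used later under autocast.

# Standalone PyTorch sketch (not from this repo): the `dtype` factory kwarg
# sets the dtype of the parameters/buffers a module is created with.
import torch
import torch.nn as nn

emb = nn.Embedding(32, 8, dtype=torch.float32)
norm = nn.LayerNorm(8, eps=1e-6, dtype=torch.float32)
pe = torch.zeros(1, 256, 8, dtype=torch.float32)  # analogous to the `pe` buffer above

print(emb.weight.dtype, norm.weight.dtype, pe.dtype)  # all torch.float32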

algoperf/workloads/wmt/wmt_pytorch/workload.py

Lines changed: 24 additions & 17 deletions
@@ -24,6 +24,11 @@
 class WmtWorkload(BaseWmtWorkload):
   """WMT PyTorch workload."""
 
+  def __init__(self) -> None:
+    super().__init__()
+    self._param_dtype_pt = spec.PYTORCH_DTYPE_MAP[self._param_dtype]
+    self._compute_dtype_pt = spec.PYTORCH_DTYPE_MAP[self._compute_dtype]
+
   def compute_weighted_cross_entropy(
     self,
     logits: spec.Tensor,
@@ -189,6 +194,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
       attention_temp=self.attention_temp,
       activation=activation,
       glu=self.glu,
+      dtype=self._param_dtype_pt,
     )
     self._param_shapes = param_utils.pytorch_param_shapes(model)
     self._param_types = param_utils.pytorch_param_types(self._param_shapes)
@@ -228,23 +234,24 @@ def model_fn(
     }
 
     with contexts[mode]():
-      logits_batch = model(
-        src=augmented_and_preprocessed_input_batch['inputs'],
-        tgt=augmented_and_preprocessed_input_batch['targets'],
-        inputs_positions=augmented_and_preprocessed_input_batch.get(
-          'inputs_position', None
-        ),
-        targets_positions=augmented_and_preprocessed_input_batch.get(
-          'targets_position', None
-        ),
-        inputs_segmentation=augmented_and_preprocessed_input_batch.get(
-          'inputs_segmentation', None
-        ),
-        targets_segmentation=augmented_and_preprocessed_input_batch.get(
-          'targets_segmentation', None
-        ),
-        dropout_rate=dropout_rate,
-      )
+      with torch.autocast(device_type='cuda', dtype=self._compute_dtype_pt):
+        logits_batch = model(
+          src=augmented_and_preprocessed_input_batch['inputs'],
+          tgt=augmented_and_preprocessed_input_batch['targets'],
+          inputs_positions=augmented_and_preprocessed_input_batch.get(
+            'inputs_position', None
+          ),
+          targets_positions=augmented_and_preprocessed_input_batch.get(
+            'targets_position', None
+          ),
+          inputs_segmentation=augmented_and_preprocessed_input_batch.get(
+            'inputs_segmentation', None
+          ),
+          targets_segmentation=augmented_and_preprocessed_input_batch.get(
+            'targets_segmentation', None
+          ),
+          dropout_rate=dropout_rate,
+        )
 
     return logits_batch, None
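The autocast wrapper above is the standard PyTorch mixed-precision pattern; here is a minimal standalone sketch (with a CPU fallback so it runs anywhere, unlike the CUDA-only context in the diff):

# Standalone autocast sketch: parameters stay float32, while autocast-eligible
# ops (the linear matmul here) run in bfloat16 inside the context.
import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Linear(16, 4).to(device)  # float32 parameters
x = torch.randn(2, 16, device=device)

with torch.autocast(device_type=device, dtype=torch.bfloat16):
  y = model(x)

print(model.weight.dtype)  # torch.float32
print(y.dtype)  # torch.bfloat16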

algoperf/workloads/wmt/workload.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,8 @@ class BaseWmtWorkload(spec.Workload):
   """A WMT workload."""
 
   _vocab_size: int = 32000
+  _compute_dtype: spec.DTYPE = spec.DTYPE.BFLOAT16
+  _param_dtype: spec.DTYPE = spec.DTYPE.FLOAT32
 
   def __init__(self) -> None:
     super().__init__()
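These two class attributes are the single place the mixed-precision defaults live for both backends. The sketch below is a hypothetical standalone stand-in for the dtype enum and per-framework maps referenced in the diffs (`spec.DTYPE`, `spec.JAX_DTYPE_MAP`, `spec.PYTORCH_DTYPE_MAP`); the real definitions live in algoperf's spec module and are not part of this commit.

# Hypothetical sketch only: an abstract dtype enum plus per-framework lookup
# tables, so a workload can declare _compute_dtype / _param_dtype once and each
# backend resolves them to its own dtype objects.
import enum

import jax.numpy as jnp
import torch


class DTYPE(enum.Enum):
  FLOAT32 = 'float32'
  BFLOAT16 = 'bfloat16'


JAX_DTYPE_MAP = {DTYPE.FLOAT32: jnp.float32, DTYPE.BFLOAT16: jnp.bfloat16}
PYTORCH_DTYPE_MAP = {DTYPE.FLOAT32: torch.float32, DTYPE.BFLOAT16: torch.bfloat16}

print(JAX_DTYPE_MAP[DTYPE.BFLOAT16], PYTORCH_DTYPE_MAP[DTYPE.BFLOAT16])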
