Skip to content

Commit 429ad30

Browse files
author
Alexander
committed
temp
1 parent 4634480 commit 429ad30

File tree

4 files changed

+418
-53
lines changed

4 files changed

+418
-53
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ These functions are just for your information. They are internally used, however
6161

6262
### Gradient Computation
6363
`mpx` provides function decorators for gradient calculations that summarize steps 3--9 in one function call. They have the same meaning and syntax as the corresponding decorators of `equinox`. This means that, for an existing training pipeline, one can replace calls to `equinox.filter_grad/filter_value_and_grad` with `mpx.filter_grad/filter_value_and_grad`.
64-
- `filter_grad(func, scaling: loss_scaling.DynamicLossScaling, has_aux=False, use_mixed_precision=True)`: Transformation that computes the gradient of func with respect to its first argument using mixed precision with scaling, similar to `equinox.filter_grad`. The decorator works as follows:
64+
- `filter_grad(func, scaling: loss_scaling.DynamicLossScaling, has_aux=False, use_mixed_precision=True)`: Transformation that computes the gradient of func with respect to its first argument using mixed precision with scaling, similar to `equinox.filter_grad`. The transformed function then works as follows:
6565
1. If `use_mixed_precision` is True:
6666
- Casts all input arguments to half precision (float16/bfloat16)
6767
- Scales the function's output by `scaling`

examples/Bert.py

Lines changed: 80 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,12 @@
1616
from tqdm import notebook as tqdm # https://github.com/tqdm/tqdm
1717
from transformers import AutoTokenizer # https://github.com/huggingface/transformers
1818

19+
import einshape as es
20+
1921
from examples.transformer import TransformerLayer
2022

23+
import mpx
24+
2125
class EmbedderBlock(eqx.Module):
2226
"""BERT embedder."""
2327

@@ -59,7 +63,7 @@ def __call__(
5963
position_ids: Array,
6064
segment_ids: Array,
6165
enable_dropout: bool = False,
62-
key: jax.random.PRNGKey | None = None,
66+
key: jax.random.PRNGKey = None,
6367
) -> Array:
6468
tokens = jax.vmap(self.token_embedder)(token_ids)
6569
segments = jax.vmap(self.segment_embedder)(segment_ids)
@@ -129,7 +133,7 @@ def __call__(
129133
segment_ids: Array,
130134
*,
131135
enable_dropout: bool = False,
132-
key: jax.random.PRNGKey | None = None,
136+
key: jax.random.PRNGKey = None,
133137
) -> dict[str, Array]:
134138
emb_key, l_key = (None, None) if key is None else jax.random.split(key)
135139

@@ -216,21 +220,20 @@ def compute_loss(classifier, inputs, key):
216220
batch_size = inputs["token_ids"].shape[0]
217221
batched_keys = jax.random.split(key, num=batch_size)
218222
logits = jax.vmap(classifier, in_axes=(0, None, 0))(inputs, True, batched_keys)
219-
return jnp.mean(
220-
optax.softmax_cross_entropy_with_integer_labels(
223+
# all of these operations are done in full precision
224+
return mpx.force_full_precision(jnp.mean)(
225+
mpx.force_full_precision(optax.softmax_cross_entropy_with_integer_labels)(
221226
logits=logits, labels=inputs["label"]
222227
)
223228
)
224229

225230

226-
def make_step(model, inputs, opt_state, key, tx):
231+
def make_step(model, inputs, opt_state, key, tx, scaling: mpx.DynamicLossScaling):
227232
key, new_key = jax.random.split(key)
228-
loss, grads = compute_loss(model, inputs, key)
229-
grads = jax.lax.pmean(grads, axis_name="devices")
230-
231-
updates, opt_state = tx.update(grads, opt_state, model)
232-
model = eqx.apply_updates(model, updates)
233-
return loss, model, opt_state, new_key
233+
loss, scaling, grads_finite, grads = mpx.filter_value_and_grad(compute_loss, scaling)(model, inputs, key)
234+
235+
model, opt_state = mpx.optimizer_update(model, tx, opt_state, grads, grads_finite)
236+
return loss, model, opt_state, new_key, scaling
234237

235238

236239
def make_eval_step(model, inputs):
@@ -239,6 +242,7 @@ def make_eval_step(model, inputs):
239242
if __name__ == "__main__":
240243
# Tiny-BERT config.
241244
bert_config = {
245+
"train_mixed_precision": True,
242246
"vocab_size": 30522,
243247
"hidden_size": 128,
244248
"num_hidden_layers": 2,
@@ -271,36 +275,76 @@ def tokenize(example):
271275
batch_size = 32
272276
learning_rate = 1e-5
273277

278+
############################
279+
# init model
280+
############################
281+
model = BertClassifier(config=bert_config, num_classes=2, key=model_key)
282+
283+
############################
284+
# init optimizer
285+
############################
286+
tx = optax.adam(learning_rate=learning_rate)
287+
tx = optax.chain(optax.clip_by_global_norm(1.0), tx)
288+
opt_state = tx.init(model)
289+
290+
############################
291+
# init scaling
292+
############################
293+
if bert_config["train_mixed_precision"]:
294+
loss_scaling = mpx.DynamicLossScaling(loss_scaling=jnp.ones((1,), dtype=jnp.float32) * int((2 - 2**(-10)) * 2**15),
295+
min_loss_scaling=jnp.ones((1,), dtype=jnp.float32) * 1.0)
296+
else:
297+
loss_scaling = None
298+
299+
############################
300+
# training
301+
############################
274302
for epoch in range(epochs):
275-
with tqdm.tqdm(
276-
ds["train"].iter(batch_size=batch_size, drop_last_batch=True),
277-
total=ds["train"].num_rows // batch_size,
278-
unit="steps",
279-
desc=f"Epoch {epoch+1}/{epochs}",
280-
) as tqdm_epoch:
281-
for batch in tqdm_epoch:
303+
with tqdm.tqdm(
304+
ds["train"].iter(batch_size=batch_size, drop_last_batch=True),
305+
total=ds["train"].num_rows // batch_size,
306+
unit="steps",
307+
desc=f"Epoch {epoch+1}/{epochs}",
308+
) as tqdm_epoch:
309+
310+
for batch in tqdm_epoch:
311+
token_ids, token_type_ids = batch["input_ids"], batch["token_type_ids"]
312+
label = batch["label"]
313+
314+
# swap time and feature axis.
315+
token_ids = es.jax_einshape("bhn->bnh", token_ids)
316+
token_type_ids = es.jax_einshape("bhn->bnh", token_type_ids)
317+
318+
inputs = {
319+
"token_ids": token_ids,
320+
"segment_ids": token_type_ids,
321+
"label": label,
322+
}
323+
loss, model, opt_state, train_key, loss_scaling = make_step(
324+
model, inputs, opt_state, train_key, tx, scaling=loss_scaling
325+
)
326+
327+
tqdm_epoch.set_postfix(loss=np.sum(loss).item())
328+
329+
outputs = []
330+
for batch in tqdm.tqdm(
331+
ds["validation"].iter(batch_size=batch_size),
332+
unit="steps",
333+
total=np.ceil(ds["validation"].num_rows / batch_size),
334+
desc="Validation",
335+
):
282336
token_ids, token_type_ids = batch["input_ids"], batch["token_type_ids"]
283337
label = batch["label"]
284338

285-
# Split batch across devices.
286-
token_ids = einops.rearrange(
287-
token_ids, "(b1 b2) s -> b1 b2 s", b1=num_devices
288-
)
289-
token_type_ids = einops.rearrange(
290-
token_type_ids, "(b1 b2) s -> b1 b2 s", b1=num_devices
291-
)
292-
label = einops.rearrange(label, "(b1 b2) -> b1 b2", b1=num_devices)
293-
294-
inputs = {
295-
"token_ids": token_ids,
296-
"segment_ids": token_type_ids,
297-
"label": label,
298-
}
299-
loss, model, opt_state, train_key = p_make_step(
300-
model, inputs, opt_state, train_key
301-
)
302339

303-
tqdm_epoch.set_postfix(loss=np.sum(loss).item())
340+
inputs = {"token_ids": token_ids, "segment_ids": token_type_ids}
341+
342+
# Compare predicted class with label.
343+
output = make_eval_step(model, inputs)
344+
output = map(float, np.argmax(output.reshape(-1, 2), axis=-1) == label)
345+
outputs.extend(output)
346+
347+
print(f"Accuracy: {100 * np.sum(outputs) / len(outputs):.2f}%")
304348

305349

306350

0 commit comments

Comments
 (0)