Commit 870de14

Alexander committed
casting functions now only cast float arrays
1 parent dae03f2 commit 870de14

File tree: 6 files changed, +514 −9 lines changed


examples/Bert.py

Lines changed: 245 additions & 0 deletions
@@ -0,0 +1,245 @@
"""
This is a port of the Bert example from Equinox (https://docs.kidger.site/equinox/examples/bert/).
"""

import functools
from collections.abc import Mapping

import einops  # https://github.com/arogozhnikov/einops
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax  # https://github.com/deepmind/optax
from datasets import load_dataset  # https://github.com/huggingface/datasets
from jaxtyping import Array, Float, Int  # https://github.com/google/jaxtyping
from tqdm import notebook as tqdm  # https://github.com/tqdm/tqdm
from transformers import AutoTokenizer  # https://github.com/huggingface/transformers

from examples.transformer import TransformerLayer


class EmbedderBlock(eqx.Module):
    """BERT embedder."""

    token_embedder: eqx.nn.Embedding
    segment_embedder: eqx.nn.Embedding
    position_embedder: eqx.nn.Embedding
    layernorm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout

    def __init__(
        self,
        vocab_size: int,
        max_length: int,
        type_vocab_size: int,
        embedding_size: int,
        hidden_size: int,
        dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        token_key, segment_key, position_key = jax.random.split(key, 3)

        self.token_embedder = eqx.nn.Embedding(
            num_embeddings=vocab_size, embedding_size=embedding_size, key=token_key
        )
        self.segment_embedder = eqx.nn.Embedding(
            num_embeddings=type_vocab_size,
            embedding_size=embedding_size,
            key=segment_key,
        )
        self.position_embedder = eqx.nn.Embedding(
            num_embeddings=max_length, embedding_size=embedding_size, key=position_key
        )
        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        token_ids: Int[Array, " seq_len"],
        position_ids: Int[Array, " seq_len"],
        segment_ids: Int[Array, " seq_len"],
        enable_dropout: bool = False,
        key: jax.random.PRNGKey | None = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        tokens = jax.vmap(self.token_embedder)(token_ids)
        segments = jax.vmap(self.segment_embedder)(segment_ids)
        positions = jax.vmap(self.position_embedder)(position_ids)
        embedded_inputs = tokens + segments + positions
        embedded_inputs = jax.vmap(self.layernorm)(embedded_inputs)
        embedded_inputs = self.dropout(
            embedded_inputs, inference=not enable_dropout, key=key
        )
        return embedded_inputs
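
# Added note (not in the original commit): eqx.nn.Embedding and
# eqx.nn.LayerNorm act on one element at a time, so __call__ above vmaps
# them over the seq_len axis before summing the three embeddings, BERT-style.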


class Encoder(eqx.Module):
    """Full BERT encoder."""

    embedder_block: EmbedderBlock
    layers: list[TransformerLayer]
    pooler: eqx.nn.Linear

    def __init__(
        self,
        vocab_size: int,
        max_length: int,
        type_vocab_size: int,
        embedding_size: int,
        hidden_size: int,
        intermediate_size: int,
        num_layers: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        embedder_key, layer_key, pooler_key = jax.random.split(key, num=3)
        self.embedder_block = EmbedderBlock(
            vocab_size=vocab_size,
            max_length=max_length,
            type_vocab_size=type_vocab_size,
            embedding_size=embedding_size,
            hidden_size=hidden_size,
            dropout_rate=dropout_rate,
            key=embedder_key,
        )

        layer_keys = jax.random.split(layer_key, num=num_layers)
        self.layers = []
        for layer_key in layer_keys:
            self.layers.append(
                TransformerLayer(
                    hidden_size=hidden_size,
                    intermediate_size=intermediate_size,
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                    attention_dropout_rate=attention_dropout_rate,
                    key=layer_key,
                )
            )

        self.pooler = eqx.nn.Linear(
            in_features=hidden_size, out_features=hidden_size, key=pooler_key
        )

    def __call__(
        self,
        token_ids: Int[Array, " seq_len"],
        position_ids: Int[Array, " seq_len"],
        segment_ids: Int[Array, " seq_len"],
        *,
        enable_dropout: bool = False,
        key: jax.random.PRNGKey | None = None,
    ) -> dict[str, Array]:
        emb_key, l_key = (None, None) if key is None else jax.random.split(key)

        embeddings = self.embedder_block(
            token_ids=token_ids,
            position_ids=position_ids,
            segment_ids=segment_ids,
            enable_dropout=enable_dropout,
            key=emb_key,
        )

        # We assume that all 0-values should be masked out.
        mask = jnp.asarray(token_ids != 0, dtype=jnp.int32)
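        # Illustration (added comment): with BERT's uncased vocab, 0 is the
        # [PAD] id, so token_ids [101, 7592, 102, 0, 0] ("[CLS] hello [SEP]"
        # plus padding) yield mask [1, 1, 1, 0, 0].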

        x = embeddings
        layer_outputs = []
        for layer in self.layers:
            cl_key, l_key = (None, None) if l_key is None else jax.random.split(l_key)
            x = layer(x, mask, enable_dropout=enable_dropout, key=cl_key)
            layer_outputs.append(x)

        # BERT pooling.
        # The first token in the last layer is the embedding of the "[CLS]" token.
        first_token_last_layer = x[..., 0, :]
        pooled = self.pooler(first_token_last_layer)
        pooled = jnp.tanh(pooled)

        return {"embeddings": embeddings, "layers": layer_outputs, "pooled": pooled}


class BertClassifier(eqx.Module):
    """BERT classifier."""

    encoder: Encoder
    classifier_head: eqx.nn.Linear
    dropout: eqx.nn.Dropout

    def __init__(self, config: Mapping, num_classes: int, key: jax.random.PRNGKey):
        encoder_key, head_key = jax.random.split(key)

        self.encoder = Encoder(
            vocab_size=config["vocab_size"],
            max_length=config["max_position_embeddings"],
            type_vocab_size=config["type_vocab_size"],
            embedding_size=config["hidden_size"],
            hidden_size=config["hidden_size"],
            intermediate_size=config["intermediate_size"],
            num_layers=config["num_hidden_layers"],
            num_heads=config["num_attention_heads"],
            dropout_rate=config["hidden_dropout_prob"],
            attention_dropout_rate=config["attention_probs_dropout_prob"],
            key=encoder_key,
        )
        self.classifier_head = eqx.nn.Linear(
            in_features=config["hidden_size"], out_features=num_classes, key=head_key
        )
        self.dropout = eqx.nn.Dropout(config["hidden_dropout_prob"])

    def __call__(
        self,
        inputs: dict[str, Int[Array, " seq_len"]],
        enable_dropout: bool = True,
        key: jax.random.PRNGKey | None = None,
    ) -> Float[Array, " num_classes"]:
        seq_len = inputs["token_ids"].shape[-1]
        position_ids = jnp.arange(seq_len)

        e_key, d_key = (None, None) if key is None else jax.random.split(key)

        pooled_output = self.encoder(
            token_ids=inputs["token_ids"],
            segment_ids=inputs["segment_ids"],
            position_ids=position_ids,
            enable_dropout=enable_dropout,
            key=e_key,
        )["pooled"]
        pooled_output = self.dropout(
            pooled_output, inference=not enable_dropout, key=d_key
        )

        return self.classifier_head(pooled_output)
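
# Added note (not in the original commit): the recurring pattern
# `(None, None) if key is None else jax.random.split(key)` threads a single
# PRNG key through every dropout call while letting callers omit the key
# entirely when dropout is disabled.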


if __name__ == "__main__":
    # Tiny-BERT config.
    bert_config = {
        "vocab_size": 30522,
        "hidden_size": 128,
        "num_hidden_layers": 2,
        "num_attention_heads": 2,
        "hidden_act": "gelu",
        "intermediate_size": 512,
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
        "max_position_embeddings": 512,
        "type_vocab_size": 2,
        "initializer_range": 0.02,
    }

    key = jax.random.PRNGKey(5678)
    model_key, train_key = jax.random.split(key)
    classifier = BertClassifier(config=bert_config, num_classes=2, key=model_key)

    tokenizer = AutoTokenizer.from_pretrained(
        "google/bert_uncased_L-2_H-128_A-2", model_max_length=128
    )

    def tokenize(example):
        return tokenizer(example["sentence"], padding="max_length", truncation=True)

    ds = load_dataset("sst2")
    ds = ds.map(tokenize, batched=True)
    ds.set_format(type="jax", columns=["input_ids", "token_type_ids", "label"])

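Editorial addendum, not part of the commit: the file stops after the dataset setup, so here is a minimal sketch of one forward pass through the classes above. It assumes the sst2/tokenizer setup from the __main__ block and renames the HuggingFace columns (input_ids, token_type_ids) to the names the model expects (token_ids, segment_ids); none of this is the author's training code.

    example = ds["train"][0]
    inputs = {
        "token_ids": jnp.asarray(example["input_ids"]),
        "segment_ids": jnp.asarray(example["token_type_ids"]),
    }
    # Dropout is off, so no PRNG key is needed.
    logits = classifier(inputs, enable_dropout=False)  # shape (num_classes,) = (2,)
    # For a whole batch, map over the leading axis while keeping kwargs fixed:
    # batched = jax.vmap(functools.partial(classifier, enable_dropout=False))(batch)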