
Commit 7579f80

add box embedding model
1 parent d52b422 commit 7579f80

File tree

  chebai/models/electra.py
  chebai/models/electra_box.py
  configs/model/box-electra.yml

3 files changed: +179 -57 lines changed

chebai/models/electra.py

Lines changed: 61 additions & 57 deletions
@@ -161,17 +161,18 @@ def filter_dict(d: Dict[str, Any], filter_key: str) -> Dict[str, Any]:
     }


-class Electra(ChebaiBaseNet):
-    """
-    Electra model implementation inherited from ChebaiBaseNet.
+class ElectraProcessingMixIn:
+    """Mixin class for processing batches and outputs for Electra models."""

-    Args:
-        config (Dict[str, Any], optional): Configuration parameters for the Electra model. Defaults to None.
-        pretrained_checkpoint (str, optional): Path to the pretrained checkpoint file. Defaults to None.
-        load_prefix (str, optional): Prefix to filter the state_dict keys from the pretrained checkpoint. Defaults to None.
-        **kwargs: Additional keyword arguments.
+    @property
+    def as_pretrained(self) -> ElectraModel:
+        """
+        Get the pretrained Electra model.

-    """
+        Returns:
+            ElectraModel: The pretrained Electra model.
+        """
+        return self.electra.electra

     def _process_batch(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
         """
@@ -209,15 +210,61 @@ def _process_batch(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
             idents=batch.additional_fields["idents"],
         )

-    @property
-    def as_pretrained(self) -> ElectraModel:
+    def _process_for_loss(
+        self,
+        model_output: Dict[str, Tensor],
+        labels: Tensor,
+        loss_kwargs: Dict[str, Any],
+    ) -> Tuple[Tensor, Tensor, Dict[str, Any]]:
         """
-        Get the pretrained Electra model.
+        Process the model output for calculating the loss.
+
+        Args:
+            model_output (Dict[str, Tensor]): The output of the model.
+            labels (Tensor): The target labels.
+            loss_kwargs (Dict[str, Any]): Additional loss arguments.

         Returns:
-            ElectraModel: The pretrained Electra model.
+            tuple: A tuple containing the processed model output, labels, and loss arguments.
         """
-        return self.electra.electra
+        kwargs_copy = dict(loss_kwargs)
+        if labels is not None:
+            labels = labels.float()
+        return model_output["logits"], labels, kwargs_copy
+
+    def _get_prediction_and_labels(
+        self, data: Dict[str, Any], labels: Tensor, model_output: Dict[str, Tensor]
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Get the predictions and labels from the model output. Applies a sigmoid to the model output.
+
+        Args:
+            data (Dict[str, Any]): The input data.
+            labels (Tensor): The target labels.
+            model_output (Dict[str, Tensor]): The output of the model.
+
+        Returns:
+            tuple: A tuple containing the predictions and labels.
+        """
+        d = model_output["logits"]
+        loss_kwargs = data.get("loss_kwargs", dict())
+        if "non_null_labels" in loss_kwargs:
+            n = loss_kwargs["non_null_labels"]
+            d = d[n]
+        return torch.sigmoid(d), labels.int() if labels is not None else None
+
+
+class Electra(ElectraProcessingMixIn, ChebaiBaseNet):
+    """
+    Electra model implementation inherited from ChebaiBaseNet.
+
+    Args:
+        config (Dict[str, Any], optional): Configuration parameters for the Electra model. Defaults to None.
+        pretrained_checkpoint (str, optional): Path to the pretrained checkpoint file. Defaults to None.
+        load_prefix (str, optional): Prefix to filter the state_dict keys from the pretrained checkpoint. Defaults to None.
+        **kwargs: Additional keyword arguments.
+
+    """

     def __init__(
         self,
@@ -262,49 +309,6 @@ def __init__(
         else:
             self.electra = ElectraModel(config=self.config)

-    def _process_for_loss(
-        self,
-        model_output: Dict[str, Tensor],
-        labels: Tensor,
-        loss_kwargs: Dict[str, Any],
-    ) -> Tuple[Tensor, Tensor, Dict[str, Any]]:
-        """
-        Process the model output for calculating the loss.
-
-        Args:
-            model_output (Dict[str, Tensor]): The output of the model.
-            labels (Tensor): The target labels.
-            loss_kwargs (Dict[str, Any]): Additional loss arguments.
-
-        Returns:
-            tuple: A tuple containing the processed model output, labels, and loss arguments.
-        """
-        kwargs_copy = dict(loss_kwargs)
-        if labels is not None:
-            labels = labels.float()
-        return model_output["logits"], labels, kwargs_copy
-
-    def _get_prediction_and_labels(
-        self, data: Dict[str, Any], labels: Tensor, model_output: Dict[str, Tensor]
-    ) -> Tuple[Tensor, Tensor]:
-        """
-        Get the predictions and labels from the model output. Applies a sigmoid to the model output.
-
-        Args:
-            data (Dict[str, Any]): The input data.
-            labels (Tensor): The target labels.
-            model_output (Dict[str, Tensor]): The output of the model.
-
-        Returns:
-            tuple: A tuple containing the predictions and labels.
-        """
-        d = model_output["logits"]
-        loss_kwargs = data.get("loss_kwargs", dict())
-        if "non_null_labels" in loss_kwargs:
-            n = loss_kwargs["non_null_labels"]
-            d = d[n]
-        return torch.sigmoid(d), labels.int() if labels is not None else None
-
     def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
         """
         Forward pass of the Electra model.

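Note (not part of the commit): neither _process_for_loss nor _get_prediction_and_labels reads any instance attributes, so the refactored mixin can be sanity-checked in isolation. A minimal sketch, assuming the chebai package from this repository is importable:

import torch

from chebai.models.electra import ElectraProcessingMixIn

mixin = ElectraProcessingMixIn()
fake_output = {"logits": torch.randn(4, 3)}   # pretend batch of 4 samples, 3 labels
labels = torch.randint(0, 2, (4, 3))

# _process_for_loss passes the logits through and casts labels to float.
logits, float_labels, loss_kwargs = mixin._process_for_loss(fake_output, labels, {})

# _get_prediction_and_labels applies a sigmoid and casts labels to int.
preds, int_labels = mixin._get_prediction_and_labels({"loss_kwargs": {}}, labels, fake_output)

print(float_labels.dtype, int_labels.dtype)        # torch.float32 torch.int32
print(preds.min().item() >= 0, preds.max().item() <= 1)   # True True: predictions are probabilities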
chebai/models/electra_box.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
+import torch
+import torch.nn as nn
+from transformers import ElectraConfig, ElectraModel
+
+from chebai.models.base import ChebaiBaseNet
+from chebai.models.electra import ElectraProcessingMixIn, filter_dict
+
+
+class ElectraBox(ElectraProcessingMixIn, ChebaiBaseNet):
+    NAME = "ElectraBox"
+
+    def __init__(
+        self, config=None, pretrained_checkpoint=None, load_prefix=None, **kwargs
+    ):
+        super().__init__(**kwargs)
+        if config is None:
+            config = dict()
+        if "num_labels" not in config and self.out_dim is not None:
+            config["num_labels"] = self.out_dim
+        self.config = ElectraConfig(**config, output_attentions=True)
+        self.word_dropout = nn.Dropout(config.get("word_dropout", 0))
+
+        self.in_dim = self.config.hidden_size
+        self.hidden_dim = self.config.embeddings_to_points_hidden_size
+        self.out_dim = self.config.embeddings_dimensions
+        self.boxes = nn.Parameter(torch.rand((self.config.num_labels, self.out_dim, 2)))
+        self.embeddings_to_points = nn.Sequential(
+            nn.Linear(self.in_dim, self.hidden_dim),
+            nn.ReLU(),
+            nn.Linear(self.hidden_dim, self.hidden_dim),
+            nn.ReLU(),
+            nn.Dropout(0.1),
+            nn.Linear(self.hidden_dim, self.out_dim),
+        )
+
+        if pretrained_checkpoint:
+            with open(pretrained_checkpoint, "rb") as fin:
+                model_dict = torch.load(fin, map_location=self.device)
+                if load_prefix:
+                    state_dict = filter_dict(model_dict["state_dict"], load_prefix)
+                else:
+                    state_dict = model_dict["state_dict"]
+                self.electra = ElectraModel.from_pretrained(
+                    None, state_dict=state_dict, config=self.config
+                )
+        else:
+            self.electra = ElectraModel(config=self.config)
+
+    def forward(self, data, **kwargs):
+        self.batch_size = data["features"].shape[0]
+        inp = self.electra.embeddings.forward(data["features"])
+        inp = self.word_dropout(inp)
+        electra = self.electra(inputs_embeds=inp)
+        d = electra.last_hidden_state[:, 0, :]
+
+        points = self.embeddings_to_points(d)
+
+        b = self.boxes.expand(self.batch_size, -1, -1, -1)
+        raw_l = torch.min(b, dim=-1)[0]
+        raw_r = torch.max(b, dim=-1)[0]
+
+        left = raw_l + ((raw_r - raw_l) * 0.2)
+        right = raw_r - ((raw_r - raw_l) * 0.2)
+
+        p = points.expand(self.config.num_labels, -1, -1).transpose(1, 0)
+        max_distance_per_dim = torch.max(
+            torch.stack((nn.functional.relu(left - p), nn.functional.relu(p - right))),
+            dim=0,
+        )[0]
+
+        m = torch.sum(max_distance_per_dim, dim=-1)
+        s = 2 - (2 * (torch.sigmoid(m)))
+        logits = torch.logit((s * 0.99) + 0.001)
+
+        return dict(
+            boxes=b,
+            embedded_points=points,
+            logits=logits,
+            attentions=electra.attentions,
+            target_mask=data.get("target_mask"),
+        )
+
+
+if __name__ == "__main__":
+    model = ElectraBox(
+        config={
+            "vocab_size": 4400,
+            "max_position_embeddings": 1800,
+            "num_attention_heads": 8,
+            "num_hidden_layers": 6,
+            "type_vocab_size": 1,
+            "hidden_size": 256,
+            "embeddings_to_points_hidden_size": 1200,
+            "embeddings_dimensions": 16,
+        },
+        out_dim=120,
+        input_dim=1800,
+    )
+    import torch
+
+    print(
+        model._process_for_loss(
+            torch.randint(0, 4400, (2, 1800)), torch.randint(0, 2, (2, 120))
+        )
+    )

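The new forward pass treats each class as an axis-aligned box in a low-dimensional embedding space and each molecule as a point: a class is predicted when the point lies inside a slightly shrunken version of that class's box. The following standalone sketch, not part of the commit, reproduces just this scoring step with random tensors; the sizes match the config below but are otherwise illustrative:

import torch
import torch.nn as nn

batch_size, num_labels, dim = 4, 120, 16          # illustrative sizes

points = torch.randn(batch_size, dim)             # one embedded point per sample
boxes = torch.rand(num_labels, dim, 2)            # two corners per label box

b = boxes.expand(batch_size, -1, -1, -1)
raw_l = torch.min(b, dim=-1)[0]                   # lower corner per (label, dimension)
raw_r = torch.max(b, dim=-1)[0]                   # upper corner per (label, dimension)

# Shrink each box by 20% on both sides, as in the commit.
left = raw_l + (raw_r - raw_l) * 0.2
right = raw_r - (raw_r - raw_l) * 0.2

# Per-dimension violation: zero if the coordinate lies inside [left, right].
p = points.expand(num_labels, -1, -1).transpose(1, 0)
violation = torch.max(
    torch.stack((nn.functional.relu(left - p), nn.functional.relu(p - right))), dim=0
)[0]

m = violation.sum(dim=-1)                         # total violation per (sample, label)
s = 2 - 2 * torch.sigmoid(m)                      # 1 inside the box, approaches 0 far outside
logits = torch.logit(s * 0.99 + 0.001)            # keeps the argument strictly inside (0, 1)

print(logits.shape)                               # torch.Size([4, 120])

A point inside a box gives m = 0, hence s = 1 and a positive logit of roughly 4.7; a point far outside drives s toward 0 and the logit toward roughly -6.9. The sigmoid applied later in _get_prediction_and_labels then turns each logit back into a per-class membership probability.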
configs/model/box-electra.yml

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+class_path: chebai.models.electra_box.ElectraBox
+init_args:
+  optimizer_kwargs:
+    lr: 1e-3
+  config:
+    vocab_size: 4400
+    max_position_embeddings: 1800
+    num_attention_heads: 8
+    num_hidden_layers: 6
+    type_vocab_size: 1
+    hidden_size: 256
+    embeddings_to_points_hidden_size: 1200
+    embeddings_dimensions: 16

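The YAML above follows the class_path / init_args layout used by jsonargparse-based CLIs, so the training entry point presumably instantiates the model from it directly. Purely as an illustration of that layout (not how the repository's CLI is actually invoked), the file can also be resolved by hand; out_dim is not part of the config and is normally supplied elsewhere, so the value below is made up to match the __main__ block in electra_box.py:

import importlib

import yaml  # assumes PyYAML is installed

with open("configs/model/box-electra.yml") as f:
    cfg = yaml.safe_load(f)

module_name, class_name = cfg["class_path"].rsplit(".", 1)
model_cls = getattr(importlib.import_module(module_name), class_name)

# out_dim=120 is an illustrative value; optimizer_kwargs and config come from init_args.
model = model_cls(out_dim=120, **cfg["init_args"])
print(type(model).__name__)   # ElectraBox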