
Commit 839d88a

fix: reviewer comments
1 parent 8df3aaf

File tree: 2 files changed (+61, -63 lines)


model2vec/train/base.py

Lines changed: 0 additions & 19 deletions
@@ -108,25 +108,6 @@ def device(self) -> str:
         """Get the device of the model."""
         return self.embeddings.weight.device
 
-    def to_static_model(self, config: dict[str, Any] | None = None) -> StaticModel:
-        """
-        Convert the model to a static model.
-
-        This is useful if you want to discard your head, and consolidate the information learned by
-        the model to use it in a downstream task.
-
-        :param config: The config used in the StaticModel. If this is set to None, it will have no config.
-        :return: A static model.
-        """
-        # Perform the forward pass on the selected device.
-        with torch.no_grad():
-            all_indices = torch.arange(len(self.embeddings.weight))[:, None].to(self.device)
-            vectors = self._encode(all_indices).cpu().numpy()
-
-        new_model = StaticModel(vectors=vectors, tokenizer=self.tokenizer, config=config)
-
-        return new_model
-
 
 class TextDataset(Dataset):
     def __init__(self, tokenized_texts: list[list[int]], targets: torch.Tensor) -> None:
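
The `to_static_model` export helper is removed from the training base class without a replacement in this file. Callers that still need to consolidate the trained embeddings into a StaticModel can reproduce the deleted logic externally. A minimal sketch, assuming a trained model `clf` that exposes the same `embeddings`, `_encode`, `tokenizer`, and `device` attributes used above (the helper name is illustrative):

    import torch
    from model2vec import StaticModel

    def export_static_model(clf, config=None):
        """Re-encode every vocabulary index and wrap the vectors in a StaticModel.

        Mirrors the logic of the removed `to_static_model` method.
        """
        with torch.no_grad():
            # One forward pass over all token ids, on the model's device.
            all_indices = torch.arange(len(clf.embeddings.weight))[:, None].to(clf.device)
            vectors = clf._encode(all_indices).cpu().numpy()
        return StaticModel(vectors=vectors, tokenizer=clf.tokenizer, config=config)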

model2vec/train/classifier.py

Lines changed: 61 additions & 44 deletions
@@ -88,79 +88,96 @@ def fit(
         self,
         X: list[str],
         y: list[str],
-        **kwargs: Any,
+        learning_rate: float = 1e-3,
+        batch_size: int = 32,
+        early_stopping_patience: int | None = 25,
+        test_size: float = 0.1,
     ) -> ClassificationStaticModel:
         """Fit a model."""
         pl.seed_everything(42)
-        classes = sorted(set(y))
-        self.classes_ = classes
-
-        if len(self.classes) != self.out_dim:
-            self.out_dim = len(self.classes)
-
-        self.head = self.construct_head()
-        self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=self.pad_id)
-
-        label_mapping = {label: idx for idx, label in enumerate(self.classes)}
-        label_counts = Counter(y)
-        if min(label_counts.values()) < 2:
-            logger.info("Some classes have less than 2 samples. Stratification is disabled.")
-            train_texts, validation_texts, train_labels, validation_labels = train_test_split(
-                X, y, test_size=0.1, random_state=42, shuffle=True
-            )
-        else:
-            train_texts, validation_texts, train_labels, validation_labels = train_test_split(
-                X, y, test_size=0.1, random_state=42, shuffle=True, stratify=y
-            )
+        self._initialize(y)
 
-        # Turn labels into a LongTensor
-        train_tokenized: list[list[int]] = [
-            encoding.ids for encoding in self.tokenizer.encode_batch_fast(train_texts, add_special_tokens=False)
-        ]
-        train_labels_tensor = torch.Tensor([label_mapping[label] for label in train_labels]).long()
-        train_dataset = TextDataset(train_tokenized, train_labels_tensor)
+        train_texts, validation_texts, train_labels, validation_labels = self._train_test_split(
+            X, y, test_size=test_size
+        )
 
-        val_tokenized: list[list[int]] = [
-            encoding.ids for encoding in self.tokenizer.encode_batch_fast(validation_texts, add_special_tokens=False)
-        ]
-        val_labels_tensor = torch.Tensor([label_mapping[label] for label in validation_labels]).long()
-        val_dataset = TextDataset(val_tokenized, val_labels_tensor)
+        train_dataset = self._prepare_dataset(train_texts, train_labels)
+        val_dataset = self._prepare_dataset(validation_texts, validation_labels)
 
-        c = ClassifierLightningModule(self)
+        c = ClassifierLightningModule(self, learning_rate=learning_rate)
 
-        batch_size = 32
         n_train_batches = len(train_dataset) // batch_size
-        callbacks: list[Callback] = [EarlyStopping(monitor="val_accuracy", mode="max", patience=5)]
+        callbacks: list[Callback] = []
+        if early_stopping_patience is not None:
+            callback = EarlyStopping(monitor="val_accuracy", mode="max", patience=early_stopping_patience)
+            callbacks.append(callback)
+
         if n_train_batches < 250:
-            trainer = pl.Trainer(max_epochs=500, callbacks=callbacks, check_val_every_n_epoch=1)
+            val_check_interval = None
+            check_val_every_epoch = 1
         else:
             val_check_interval = max(250, 2 * len(val_dataset) // batch_size)
-            trainer = pl.Trainer(
-                max_epochs=500, callbacks=callbacks, val_check_interval=val_check_interval, check_val_every_n_epoch=None
-            )
+            check_val_every_epoch = None
+        trainer = pl.Trainer(
+            max_epochs=500,
+            callbacks=callbacks,
+            val_check_interval=val_check_interval,
+            check_val_every_n_epoch=check_val_every_epoch,
+        )
 
         trainer.fit(
             c,
             train_dataloaders=train_dataset.to_dataloader(shuffle=True, batch_size=batch_size),
            val_dataloaders=val_dataset.to_dataloader(shuffle=False, batch_size=batch_size),
         )
         best_model_path = trainer.checkpoint_callback.best_model_path  # type: ignore
+        best_model_weights = torch.load(best_model_path, weights_only=True)
 
-        state_dict = {
-            k.removeprefix("model."): v for k, v in torch.load(best_model_path, weights_only=True)["state_dict"].items()
-        }
-        self.load_state_dict(state_dict)
+        state_dict = {}
+        for weight_name, weight in best_model_weights["state_dict"].items():
+            state_dict[weight_name.removeprefix("model.")] = weight
 
+        self.load_state_dict(state_dict)
         self.eval()
 
         return self
 
+    def _initialize(self, y: list[str]) -> None:
+        """Set the output dimensionality and the classes, and initialize the head."""
+        classes = sorted(set(y))
+        self.classes_ = classes
+
+        if len(self.classes) != self.out_dim:
+            self.out_dim = len(self.classes)
+
+        self.head = self.construct_head()
+        self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=self.pad_id)
+
+    def _prepare_dataset(self, X: list[str], y: list[str]) -> TextDataset:
+        """Prepare a dataset."""
+        tokenized: list[list[int]] = [
+            encoding.ids for encoding in self.tokenizer.encode_batch_fast(X, add_special_tokens=False)
+        ]
+        labels_tensor = torch.Tensor([self.classes.index(label) for label in y]).long()
+        return TextDataset(tokenized, labels_tensor)
+
+    def _train_test_split(
+        self, X: list[str], y: list[str], test_size: float
+    ) -> tuple[list[str], list[str], list[str], list[str]]:
+        """Split the data."""
+        label_counts = Counter(y)
+        if min(label_counts.values()) < 2:
+            logger.info("Some classes have less than 2 samples. Stratification is disabled.")
+            return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True)
+        return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True, stratify=y)
 
 
 class ClassifierLightningModule(pl.LightningModule):
-    def __init__(self, model: ClassificationStaticModel) -> None:
+    def __init__(self, model: ClassificationStaticModel, learning_rate: float) -> None:
         """Initialize the lightningmodule."""
         super().__init__()
         self.model = model
+        self.learning_rate = learning_rate
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Simple forward pass."""
