
Commit a285a70

Resolve compatibility issues with Datasets & Sentence Transformers (#614)
1 parent: e52eb3c


3 files changed: +11 lines, −7 lines

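At a glance, the diff makes two kinds of changes: text and label sequences coming from a `datasets.Dataset` are wrapped in `list(...)` before being handed to Sentence Transformers or the sklearn head, and the inference paths (`predict`, `predict_proba`, `evaluate`, and the teacher predictions during distillation) run under `torch.no_grad()`. As a minimal sketch of what the `list(...)` calls likely guard against (not code from this commit): depending on the installed `datasets` version, indexing a column may return a lazy, list-like column object rather than a plain Python list, which callers expecting `List[str]` can choke on.

from datasets import Dataset

# Assumption: in some `datasets` versions, ds["text"] is a lazy list-like
# column object rather than a plain list; list(...) normalizes it either way.
ds = Dataset.from_dict({"text": ["good movie", "bad movie"], "label": [1, 0]})
texts = list(ds["text"])  # a plain List[str] after conversion
assert isinstance(texts, list) and all(isinstance(t, str) for t in texts)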

src/setfit/modeling.py

Lines changed: 6 additions & 4 deletions
@@ -284,7 +284,7 @@ def fit(
             if not end_to_end:
                 self.freeze("body")
 
-            dataloader = self._prepare_dataloader(x_train, y_train, batch_size, max_length)
+            dataloader = self._prepare_dataloader(list(x_train), list(y_train), batch_size, max_length)
             criterion = self.model_head.get_loss_fn()
             optimizer = self._prepare_optimizer(head_learning_rate, body_learning_rate, l2_weight)
             scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

@@ -314,8 +314,8 @@ def fit(
             if not end_to_end:
                 self.unfreeze("body")
         else:  # train with sklearn
-            embeddings = self.model_body.encode(x_train, normalize_embeddings=self.normalize_embeddings)
-            self.model_head.fit(embeddings, y_train)
+            embeddings = self.model_body.encode(list(x_train), normalize_embeddings=self.normalize_embeddings)
+            self.model_head.fit(embeddings, list(y_train))
         if self.labels is None and self.multi_target_strategy is None:
             # Try to set the labels based on the head classes, if they exist
             # This can fail in various ways, so we catch all exceptions

@@ -477,6 +477,7 @@ def _output_type_conversion(
             outputs = torch.from_numpy(outputs)
         return outputs
 
+    @torch.no_grad()
     def predict_proba(
         self,
         inputs: Union[str, List[str]],

@@ -521,6 +522,7 @@ def predict_proba(
         outputs = self._output_type_conversion(probs, as_numpy=as_numpy)
         return outputs[0] if is_singular else outputs
 
+    @torch.no_grad()
     def predict(
         self,
         inputs: Union[str, List[str]],

@@ -556,7 +558,7 @@ def predict(
         is_singular = isinstance(inputs, str)
         if is_singular:
             inputs = [inputs]
-        embeddings = self.encode(inputs, batch_size=batch_size, show_progress_bar=show_progress_bar)
+        embeddings = self.encode(list(inputs), batch_size=batch_size, show_progress_bar=show_progress_bar)
         preds = self.model_head.predict(embeddings)
         # If labels are defined, we don't have multilabels & the output is not already strings, then we convert to string labels
         if (
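For context on the `@torch.no_grad()` decorators added above: `torch.no_grad` works both as a context manager and as a decorator, and decorating an inference entry point means no autograd graph is recorded anywhere inside the call. A standalone sketch in plain PyTorch (the `head` and `infer` names are illustrative, not SetFit code):

import torch

head = torch.nn.Linear(4, 2)  # illustrative stand-in for a classification head

@torch.no_grad()
def infer(x: torch.Tensor) -> torch.Tensor:
    # Autograd tracking is disabled for everything inside this call, so
    # inference is cheaper and the output carries no gradient history.
    return head(x)

out = infer(torch.randn(1, 4))
assert not out.requires_grad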

src/setfit/trainer.py

Lines changed: 1 addition & 0 deletions
@@ -655,6 +655,7 @@ def train_classifier(
             end_to_end=args.end_to_end,
         )
 
+    @torch.no_grad()
     def evaluate(self, dataset: Optional[Dataset] = None, metric_key_prefix: str = "test") -> Dict[str, float]:
         """
         Computes the metrics for a given classifier.
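As in modeling.py, `evaluate` only performs inference, so the decorator presumably just avoids recording autograd history while the metrics are computed.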

src/setfit/trainer_distillation.py

Lines changed: 4 additions & 3 deletions
@@ -85,12 +85,12 @@ def get_dataset(
         max_pairs: int = -1,
     ) -> Tuple[DataLoader, nn.Module, int, int]:
         x_embd_student = self.teacher_model.model_body.encode(
-            x, convert_to_tensor=self.teacher_model.has_differentiable_head
+            list(x), convert_to_tensor=self.teacher_model.has_differentiable_head
         )
         cos_sim_matrix = util.cos_sim(x_embd_student, x_embd_student)
 
         data_sampler = ContrastiveDistillationDataset(
-            x, cos_sim_matrix, args.num_iterations, args.sampling_strategy, max_pairs=max_pairs
+            list(x), cos_sim_matrix, args.num_iterations, args.sampling_strategy, max_pairs=max_pairs
         )
         dataset = Dataset.from_list(list(data_sampler))
         loss = args.loss(self.model.model_body)

@@ -105,7 +105,8 @@ def train_classifier(self, x_train: List[str], args: Optional[TrainingArguments]
             args (`TrainingArguments`, *optional*):
                 Temporarily change the training arguments for this training call.
         """
-        y_train = self.teacher_model.predict(x_train, as_numpy=not self.student_model.has_differentiable_head)
+        with torch.no_grad():
+            y_train = self.teacher_model.predict(x_train, as_numpy=not self.student_model.has_differentiable_head)
         return super().train_classifier(x_train, y_train, args)
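Note the context-manager form here, scoped so that only the teacher's pseudo-labeling pass skips autograd while the rest of the classifier training proceeds normally. A generic sketch of that pattern (the `teacher` and `features` names are illustrative, not from this diff):

import torch

teacher = torch.nn.Linear(8, 2)  # illustrative stand-in for a teacher model
features = torch.randn(16, 8)    # illustrative batch of embeddings

# Guard only the teacher's forward pass; any training code that follows
# can still build its own autograd graph as usual.
with torch.no_grad():
    pseudo_probs = teacher(features).softmax(dim=-1)

assert not pseudo_probs.requires_grad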
