Skip to content

Commit 7d1d5d0

Browse files
committed
Update docs strings
1 parent dc1328d commit 7d1d5d0

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

autointent/modules/scoring/_ptuning/ptuning.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch
99
from datasets import Dataset
1010
from peft import PromptEncoderConfig, get_peft_model
11-
from transformers import ( # type: ignore[attr-defined]
11+
from transformers import (
1212
AutoModelForSequenceClassification,
1313
AutoTokenizer,
1414
DataCollatorWithPadding,
@@ -33,14 +33,32 @@ class PTuningScorer(BaseScorer):
3333
learning_rate: Learning rate for training
3434
seed: Random seed for reproducibility
3535
report_to: Reporting tool for training logs
36-
**ptuning_kwargs: Arguments for PromptEncoderConfig
36+
**ptuning_kwargs: Arguments for PromptEncoderConfig <https://huggingface.co/docs/peft/package_reference/p_tuning#peft.PromptEncoderConfig>
3737
3838
Example:
3939
--------
4040
.. testcode::
4141
42+
from autointent.modules import PTuningScorer
43+
scorer = PTuningScorer(
44+
base_model_config="prajjwal1/bert-tiny",
45+
num_train_epochs=3,
46+
batch_size=8,
47+
task_type="SEQ_CLS",
48+
num_virtual_tokens=10
49+
)
50+
utterances = ["hello", "goodbye", "allo", "sayonara"]
51+
labels = [0, 1, 0, 1]
52+
scorer.fit(utterances, labels)
53+
test_utterances = ["hi", "bye"]
54+
probabilities = scorer.predict(test_utterances)
55+
print(probabilities)
56+
4257
.. testoutput::
4358
59+
[[0.49624494 0.5037551 ]
60+
[0.5066545 0.4933455 ]]
61+
4462
"""
4563

4664
name = "ptuning"
@@ -170,15 +188,15 @@ def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
170188

171189
data_collator = DataCollatorWithPadding(tokenizer=self._tokenizer)
172190

173-
trainer = Trainer( # type: ignore[no-untyped-call]
191+
trainer = Trainer(
174192
model=self._model,
175193
args=training_args,
176194
train_dataset=tokenized_dataset,
177195
processing_class=self._tokenizer,
178196
data_collator=data_collator,
179197
)
180198

181-
trainer.train() # type: ignore[attr-defined]
199+
trainer.train()
182200

183201
self._model.eval()
184202

0 commit comments

Comments
 (0)