Skip to content

Commit 5043777

Browse files
committed
add docstrings
1 parent e72bde5 commit 5043777

File tree

2 files changed

+57
-4
lines changed

2 files changed

+57
-4
lines changed

autointent/configs/_transformers.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,18 @@ class CrossEncoderConfig(HFModelConfig):
127127

128128

129129
class EarlyStoppingConfig(BaseModel):
130-
val_fraction: float = 0.2
131-
patience: int = 1
132-
threshold: float = 0.0
133-
metric: Literal[tuple((SCORING_METRICS_MULTILABEL | SCORING_METRICS_MULTICLASS).keys())] | None = "scoring_f1" # type: ignore[valid-type]
130+
val_fraction: float = Field(
131+
0.2,
132+
description=(
133+
"Fraction of train samples to allocate to dev set to monitor quality "
134+
"during training and perform early stopping if quality doesn't improve."
135+
),
136+
)
137+
patience: int = Field(1, description="Maximum number of epochs to wait for quality to improve.")
138+
threshold: float = Field(
139+
0.0,
140+
description="Minimum quality increment to count as an improvement. Default: any increment is counted",
141+
)
142+
metric: Literal[tuple((SCORING_METRICS_MULTILABEL | SCORING_METRICS_MULTICLASS).keys())] | None = Field( # type: ignore[valid-type]
143+
"scoring_f1", description="Metric to monitor."
144+
)

autointent/modules/scoring/_bert.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,48 @@
2929

3030

3131
class BertScorer(BaseScorer):
32+
"""Scoring module for transformer-based classification using BERT models.
33+
34+
This module uses a transformer model (like BERT) to perform intent classification.
35+
It supports both multiclass and multilabel classification tasks, with options for
36+
early stopping and various training configurations.
37+
38+
Args:
39+
classification_model_config: Config of the transformer model (HFModelConfig, str, or dict)
40+
num_train_epochs: Number of training epochs (default: 3)
41+
batch_size: Batch size for training (default: 8)
42+
learning_rate: Learning rate for training (default: 5e-5)
43+
seed: Random seed for reproducibility (default: 0)
44+
report_to: Reporting tool for training logs (e.g., "wandb", "tensorboard")
45+
early_stopping_config: Configuration for early stopping during training
46+
47+
Example:
48+
--------
49+
.. testcode::
50+
51+
from autointent.modules import BertScorer
52+
53+
# Initialize scorer with BERT model
54+
scorer = BertScorer(
55+
classification_model_config="bert-base-uncased",
56+
num_train_epochs=3,
57+
batch_size=8,
58+
learning_rate=5e-5,
59+
seed=42
60+
)
61+
62+
# Training data
63+
utterances = ["This is great!", "I didn't like it", "Awesome product", "Poor quality"]
64+
labels = [1, 0, 1, 0]
65+
66+
# Fit the model
67+
scorer.fit(utterances, labels)
68+
69+
# Make predictions
70+
test_utterances = ["Good product", "Not worth it"]
71+
probabilities = scorer.predict(test_utterances)
72+
"""
73+
3274
name = "bert"
3375
supports_multiclass = True
3476
supports_multilabel = True

0 commit comments

Comments
 (0)