11"""PTuningScorer class for ptuning-based classification."""
22
33from pathlib import Path
4- from typing import Any
4+ from typing import Any , Literal
55
6- from peft import PromptEncoderConfig , get_peft_model
6+ from peft import PromptEncoderConfig , PromptEncoderReparameterizationType , TaskType , get_peft_model
7+ from pydantic import PositiveInt
78
89from autointent import Context
910from autointent ._callbacks import REPORTERS_NAMES
@@ -53,14 +54,19 @@ class PTuningScorer(BertScorer):
 
     name = "ptuning"
 
-    def __init__(
+    def __init__(  # noqa: PLR0913
         self,
         classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
-        num_train_epochs: int = 3,
-        batch_size: int = 8,
+        num_train_epochs: PositiveInt = 3,
+        batch_size: PositiveInt = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
         report_to: REPORTERS_NAMES | None = None,  # type: ignore[valid-type]
+        encoder_reparameterization_type: Literal["MLP", "LSTM"] = "LSTM",
+        num_virtual_tokens: PositiveInt = 10,
+        encoder_dropout: float = 0.1,
+        encoder_hidden_size: PositiveInt = 128,
+        encoder_num_layers: PositiveInt = 2,
         **ptuning_kwargs: Any,  # noqa: ANN401
     ) -> None:
         super().__init__(
@@ -71,17 +77,30 @@ def __init__(
             seed=seed,
             report_to=report_to,
         )
-        self._ptuning_config = PromptEncoderConfig(task_type="SEQ_CLS", **ptuning_kwargs)
+        self._ptuning_config = PromptEncoderConfig(
+            task_type=TaskType.SEQ_CLS,
+            encoder_reparameterization_type=PromptEncoderReparameterizationType(encoder_reparameterization_type),
+            num_virtual_tokens=num_virtual_tokens,
+            encoder_dropout=encoder_dropout,
+            encoder_hidden_size=encoder_hidden_size,
+            encoder_num_layers=encoder_num_layers,
+            **ptuning_kwargs,
+        )
 
     @classmethod
-    def from_context(
+    def from_context(  # noqa: PLR0913
         cls,
         context: Context,
         classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
-        num_train_epochs: int = 3,
-        batch_size: int = 8,
+        num_train_epochs: PositiveInt = 3,
+        batch_size: PositiveInt = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
+        encoder_reparameterization_type: Literal["MLP", "LSTM"] = "LSTM",
+        num_virtual_tokens: PositiveInt = 10,
+        encoder_dropout: float = 0.1,
+        encoder_hidden_size: PositiveInt = 128,
+        encoder_num_layers: PositiveInt = 2,
         **ptuning_kwargs: Any,  # noqa: ANN401
     ) -> "PTuningScorer":
         """Create a PTuningScorer instance using a Context object.
@@ -93,6 +112,11 @@ def from_context(
             batch_size: Batch size for training
             learning_rate: Learning rate for training
             seed: Random seed for reproducibility
+            encoder_reparameterization_type: Reparameterization type for the prompt encoder
+            num_virtual_tokens: Number of virtual tokens for the prompt encoder
+            encoder_dropout: Dropout for the prompt encoder
+            encoder_hidden_size: Hidden size for the prompt encoder
+            encoder_num_layers: Number of layers for the prompt encoder
             **ptuning_kwargs: Arguments for PromptEncoderConfig
         """
         if classification_model_config is None:
@@ -107,6 +131,11 @@ def from_context(
             learning_rate=learning_rate,
             seed=seed,
             report_to=report_to,
+            encoder_reparameterization_type=encoder_reparameterization_type,
+            num_virtual_tokens=num_virtual_tokens,
+            encoder_dropout=encoder_dropout,
+            encoder_hidden_size=encoder_hidden_size,
+            encoder_num_layers=encoder_num_layers,
             **ptuning_kwargs,
         )
 
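For quick orientation, a minimal sketch of what the reworked constructor assembles, with this diff's defaults spelled out, followed by a commented caller-side view of the new keyword arguments. The peft symbols and parameter names are the ones the change itself uses; the PTuningScorer import path and the model checkpoint in the commented usage are illustrative placeholders, not taken from this diff.

from peft import PromptEncoderConfig, PromptEncoderReparameterizationType, TaskType

# Standalone equivalent of the config the new __init__ builds with its defaults.
config = PromptEncoderConfig(
    task_type=TaskType.SEQ_CLS,
    encoder_reparameterization_type=PromptEncoderReparameterizationType("LSTM"),  # or "MLP"
    num_virtual_tokens=10,
    encoder_dropout=0.1,
    encoder_hidden_size=128,
    encoder_num_layers=2,
)

# Caller-side view of the new arguments (import path is hypothetical):
# from autointent.modules import PTuningScorer
# scorer = PTuningScorer(
#     classification_model_config="bert-base-uncased",  # placeholder checkpoint
#     encoder_reparameterization_type="MLP",
#     num_virtual_tokens=16,
# )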