11"""PTuningScorer class for ptuning-based classification."""
22
33from pathlib import Path
4- from typing import Any , Literal
4+ from typing import TYPE_CHECKING , Any , Literal
55
6- from peft import PromptEncoderConfig , PromptEncoderReparameterizationType , TaskType , get_peft_model
76from pydantic import PositiveInt
87
98from autointent import Context
109from autointent ._callbacks import REPORTERS_NAMES
1110from autointent ._dump_tools import Dumper
11+ from autointent ._utils import require
1212from autointent .configs import EarlyStoppingConfig , HFModelConfig
1313from autointent .modules .scoring ._bert import BertScorer
1414
15+ if TYPE_CHECKING :
16+ from peft import PromptEncoderConfig
17+
1518
1619class PTuningScorer (BertScorer ):
1720 """PEFT P-tuning scorer.
@@ -47,6 +50,8 @@ class PTuningScorer(BertScorer):
4750
4851 name = "ptuning"
4952
53+ _ptuning_config : "PromptEncoderConfig"
54+
5055 def __init__ ( # noqa: PLR0913
5156 self ,
5257 classification_model_config : HFModelConfig | str | dict [str , Any ] | None = None ,
@@ -64,6 +69,13 @@ def __init__( # noqa: PLR0913
6469 print_progress : bool = False ,
6570 ** ptuning_kwargs : Any , # noqa: ANN401
6671 ) -> None :
72+ # Lazy import peft
73+ peft = require ("peft" , extra = "peft" )
74+ self ._PromptEncoderConfig = peft .PromptEncoderConfig
75+ self ._PromptEncoderReparameterizationType = peft .PromptEncoderReparameterizationType
76+ self ._TaskType = peft .TaskType
77+ self ._get_peft_model = peft .get_peft_model
78+
6779 super ().__init__ (
6880 classification_model_config = classification_model_config ,
6981 num_train_epochs = num_train_epochs ,
@@ -74,9 +86,9 @@ def __init__( # noqa: PLR0913
7486 early_stopping_config = early_stopping_config ,
7587 print_progress = print_progress ,
7688 )
77- self ._ptuning_config = PromptEncoderConfig (
78- task_type = TaskType .SEQ_CLS ,
79- encoder_reparameterization_type = PromptEncoderReparameterizationType (encoder_reparameterization_type ),
89+ self ._ptuning_config = self . _PromptEncoderConfig (
90+ task_type = self . _TaskType .SEQ_CLS ,
91+ encoder_reparameterization_type = self . _PromptEncoderReparameterizationType (encoder_reparameterization_type ),
8092 num_virtual_tokens = num_virtual_tokens ,
8193 encoder_dropout = encoder_dropout ,
8294 encoder_hidden_size = encoder_hidden_size ,
@@ -139,7 +151,7 @@ def from_context( # noqa: PLR0913
139151 def _initialize_model (self ) -> Any : # noqa: ANN401
140152 """Initialize the model with P-tuning configuration."""
141153 model = super ()._initialize_model ()
142- return get_peft_model (model , self ._ptuning_config )
154+ return self . _get_peft_model (model , self ._ptuning_config )
143155
144156 def dump (self , path : str ) -> None :
145- Dumper .dump (self , Path (path ), exclude = [PromptEncoderConfig ])
157+ Dumper .dump (self , Path (path ), exclude = [self . _PromptEncoderConfig ])
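
For context, the change above replaces the module-level `from peft import ...` with a lazy import through `autointent._utils.require`, so `peft` only has to be installed when a `PTuningScorer` is actually constructed (type checkers still see `PromptEncoderConfig` via the `TYPE_CHECKING` block). The sketch below shows how such a helper might behave; it is an assumption for illustration, not the actual `require` implementation in autointent.

# Hypothetical sketch only -- the real autointent._utils.require may differ.
import importlib
from types import ModuleType


def require(name: str, extra: str | None = None) -> ModuleType:
    """Import an optional dependency, raising an informative error if it is missing."""
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        hint = f" Install it with `pip install autointent[{extra}]`." if extra else ""
        msg = f"Optional dependency '{name}' is required for this feature.{hint}"
        raise ImportError(msg) from exc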