@@ -21,6 +21,7 @@ def __init__(
2121 env_variable : Optional [str ] = None ,
2222 env_file : Optional [str ] = None ,
2323 client_url : Optional [str ] = None ,
24+ disable_abbreviation_expansion : bool = False ,
2425 ** kwargs ,
2526 ):
2627 """Initialize the prompt-based pipeline.
@@ -30,7 +31,8 @@ def __init__(
3031 env_variable: Environment variable containing API key
3132 env_file: Path to file containing API key
3233 client_url: Optional URL for OpenAI client
33- **kwargs: Additional arguments for the completion function
34+ disable_abbreviation_expansion: If True, disables abbreviation expansion
35+ **kwargs: Additional arguments for the OpenAI completion function
3436 """
3537 if not self ._prompt :
3638 raise ValueError ("Subclass must define _prompt template" )
@@ -41,12 +43,17 @@ def __init__(
4143 self .env_variable = env_variable
4244 self .env_file = env_file
4345 self .client_url = client_url
44- self .kwargs = kwargs
46+
47+ # Split parameters between publang and OpenAI
48+ self .text_processing_kwargs = {
49+ "disable_abbreviation_expansion" : disable_abbreviation_expansion
50+ }
51+ self .completion_kwargs = kwargs
4552
4653 # Initialize OpenAI client
4754 self .client = self ._load_client ()
4855
49- super ().__init__ ()
56+ super ().__init__ (disable_abbreviation_expansion = disable_abbreviation_expansion )
5057
5158 def _load_client (self ) -> OpenAI :
5259 """Load the OpenAI client.
@@ -113,8 +120,8 @@ def _transform(self, inputs: dict, **kwargs) -> dict:
113120 ],
114121 "output_schema" : self ._extraction_schema .model_json_schema (),
115122 }
116- if self .kwargs :
117- completion_config .update (self .kwargs )
123+ if self .completion_kwargs :
124+ completion_config .update (self .completion_kwargs )
118125
119126 # Replace $ with $$ to escape $ signs in the prompt
120127 # (otherwise interpreted as a special character by Template())
0 commit comments