Skip to content

Commit 80e5146

Browse files
authored
Merge pull request #44 from jdkent/fix/openai_handle_disable
[FIX] ignore disable_abbreviation_expansion when passing to chat
2 parents 2350053 + 9267422 commit 80e5146

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

ns_extract/pipelines/api.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def __init__(
2121
env_variable: Optional[str] = None,
2222
env_file: Optional[str] = None,
2323
client_url: Optional[str] = None,
24+
disable_abbreviation_expansion: bool = False,
2425
**kwargs,
2526
):
2627
"""Initialize the prompt-based pipeline.
@@ -30,7 +31,8 @@ def __init__(
3031
env_variable: Environment variable containing API key
3132
env_file: Path to file containing API key
3233
client_url: Optional URL for OpenAI client
33-
**kwargs: Additional arguments for the completion function
34+
disable_abbreviation_expansion: If True, disables abbreviation expansion
35+
**kwargs: Additional arguments for the OpenAI completion function
3436
"""
3537
if not self._prompt:
3638
raise ValueError("Subclass must define _prompt template")
@@ -41,12 +43,17 @@ def __init__(
4143
self.env_variable = env_variable
4244
self.env_file = env_file
4345
self.client_url = client_url
44-
self.kwargs = kwargs
46+
47+
# Split parameters between publang and OpenAI
48+
self.text_processing_kwargs = {
49+
"disable_abbreviation_expansion": disable_abbreviation_expansion
50+
}
51+
self.completion_kwargs = kwargs
4552

4653
# Initialize OpenAI client
4754
self.client = self._load_client()
4855

49-
super().__init__()
56+
super().__init__(disable_abbreviation_expansion=disable_abbreviation_expansion)
5057

5158
def _load_client(self) -> OpenAI:
5259
"""Load the OpenAI client.
@@ -113,8 +120,8 @@ def _transform(self, inputs: dict, **kwargs) -> dict:
113120
],
114121
"output_schema": self._extraction_schema.model_json_schema(),
115122
}
116-
if self.kwargs:
117-
completion_config.update(self.kwargs)
123+
if self.completion_kwargs:
124+
completion_config.update(self.completion_kwargs)
118125

119126
# Replace $ with $$ to escape $ signs in the prompt
120127
# (otherwise interpreted as a special character by Template())

0 commit comments

Comments
 (0)