
Commit a291de2

Author: patched.codes[bot]
Commit message: Patched patchwork/steps/CallLLM/CallLLM.py
Parent: 5b501d5

File tree: 1 file changed (+23, -38 lines)

patchwork/steps/CallLLM/CallLLM.py

Lines changed: 23 additions & 38 deletions
```diff
@@ -26,43 +26,7 @@ class _InnerCallLLMResponse:
 
 
 class CallLLM(Step, input_class=CallLLMInputs, output_class=CallLLMOutputs):
-    def __init__(self, inputs: dict):
-        super().__init__(inputs)
-        # Set 'openai_key' from inputs or environment if not already set
-        inputs.setdefault("openai_api_key", os.environ.get("OPENAI_API_KEY"))
-
-        prompt_file = inputs.get("prompt_file")
-        if prompt_file is not None:
-            prompt_file_path = Path(prompt_file)
-            if not prompt_file_path.is_file():
-                raise ValueError(f'Unable to find Prompt file: "{prompt_file}"')
-            try:
-                with open(prompt_file_path, "r") as fp:
-                    self.prompts = json.load(fp)
-            except json.JSONDecodeError as e:
-                raise ValueError(f'Invalid Json Prompt file "{prompt_file}": {e}')
-        elif "prompts" in inputs.keys():
-            self.prompts = inputs["prompts"]
-        else:
-            raise ValueError('Missing required data: "prompt_file" or "prompts"')
-
-        self.call_limit = int(inputs.get("max_llm_calls", -1))
-        self.model_args = {key[len("model_") :]: value for key, value in inputs.items() if key.startswith("model_")}
-        self.save_responses_to_file = inputs.get("save_responses_to_file", None)
-        self.model = inputs.get("model", "gpt-4o-mini")
-        self.allow_truncated = inputs.get("allow_truncated", False)
-        self.file = inputs.get("file", None)
-        self.client = AioLlmClient.create_aio_client(inputs)
-        if self.client is None:
-            raise ValueError(
-                f"Model API key not found.\n"
-                f'Please login at: "{TOKEN_URL}",\n'
-                "Please go to the Integration's tab and generate an API key.\n"
-                "Please copy the access token that is generated, "
-                "and add `--patched_api_key=<token>` to the command line.\n"
-                "\n"
-                "If you are using an OpenAI API Key, please set `--openai_api_key=<token>`.\n"
-            )
+
 
     def __persist_to_file(self, contents):
        # Convert relative path to absolute path
```
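For orientation, the deleted constructor did three things: default `openai_api_key` from the `OPENAI_API_KEY` environment variable, resolve prompts with `prompt_file` taking precedence over an inline `prompts` list, and fail fast with login instructions when no LLM client could be created. Below is a minimal, self-contained sketch of the first two behaviors; `load_prompts` is a hypothetical helper written for illustration, not a function in the patchwork codebase.

```python
import json
import os
from pathlib import Path


def load_prompts(inputs: dict) -> list:
    # Hypothetical helper mirroring the deleted __init__ logic:
    # a prompt_file input wins over an inline prompts list, and
    # having neither is an error.
    inputs.setdefault("openai_api_key", os.environ.get("OPENAI_API_KEY"))
    prompt_file = inputs.get("prompt_file")
    if prompt_file is not None:
        path = Path(prompt_file)
        if not path.is_file():
            raise ValueError(f'Unable to find Prompt file: "{prompt_file}"')
        try:
            with open(path, "r") as fp:
                return json.load(fp)
        except json.JSONDecodeError as e:
            raise ValueError(f'Invalid Json Prompt file "{prompt_file}": {e}')
    if "prompts" in inputs:
        return inputs["prompts"]
    raise ValueError('Missing required data: "prompt_file" or "prompts"')
```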
```diff
@@ -121,10 +85,22 @@ def __call(self, prompts: list[list[dict]]) -> list[_InnerCallLLMResponse]:
             kwargs["file"] = Path(self.file)
 
         for prompt in prompts:
-            is_input_accepted = self.client.is_prompt_supported(model=self.model, messages=prompt, **kwargs) > 0
+            available_tokens = self.client.is_prompt_supported(model=self.model, messages=prompt, **kwargs)
+            is_input_accepted = available_tokens > 0
+
             if not is_input_accepted:
                 self.set_status(StepStatus.WARNING, "Input token limit exceeded.")
                 prompt = self.client.truncate_messages(prompt, self.model)
+
+            # Handle the case where model_max_tokens was set to -1
+            # Calculate max_tokens based on available tokens from the model after prompt
+            if hasattr(self, '_use_max_tokens') and self._use_max_tokens:
+                if available_tokens > 0:
+                    kwargs['max_tokens'] = available_tokens
+                    logger.info(f"Setting max_tokens to {available_tokens} based on available model context")
+                else:
+                    # If we can't determine available tokens, set a reasonable default
+                    logger.warning("Could not determine available tokens. Using model default.")
 
             logger.trace(f"Message sent: \n{escape(indent(pformat(prompt), ' '))}")
             try:
```
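The substantive change in this hunk is that the return value of `is_prompt_supported` is now captured instead of being reduced to a boolean: judging by how the patch uses it, a positive value is the number of tokens left in the model's context window after the prompt, which can double as a completion budget. The sketch below illustrates that budgeting pattern under stated assumptions; `CONTEXT_WINDOWS`, `count_tokens`, and `remaining_tokens` are illustrative stand-ins, not patchwork or provider APIs.

```python
# Illustrative stand-ins; not patchwork APIs.
CONTEXT_WINDOWS = {"gpt-4o-mini": 128_000}


def count_tokens(messages: list[dict]) -> int:
    # Crude approximation: roughly 4 characters per token.
    return sum(len(m["content"]) for m in messages) // 4


def remaining_tokens(model: str, messages: list[dict]) -> int:
    # Mirrors how the patch uses is_prompt_supported: a positive value
    # means the prompt fits, and the value is the budget left over
    # for the completion.
    return CONTEXT_WINDOWS.get(model, 0) - count_tokens(messages)


messages = [{"role": "user", "content": "Summarize this diff for me."}]
kwargs: dict = {}
budget = remaining_tokens("gpt-4o-mini", messages)
if budget > 0:
    kwargs["max_tokens"] = budget  # spend the whole remaining window
# otherwise leave max_tokens unset and let the provider apply its default
```

Handing the entire remaining window to `max_tokens` is a deliberate trade-off: it maximizes room for the completion, on the assumption that prompt and completion tokens are counted against the same context window.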
```diff
@@ -184,4 +160,13 @@ def __parse_model_args(self) -> dict:
             else:
                 new_model_args[key] = arg
 
+        # Handle special case for max_tokens = -1 (use maximum available tokens)
+        if 'max_tokens' in new_model_args and new_model_args['max_tokens'] == -1:
+            # Will be handled during the chat completion call
+            logger.info("Using maximum available tokens for the model")
+            del new_model_args['max_tokens']  # Remove it for now, we'll calculate it later
+            self._use_max_tokens = True
+        else:
+            self._use_max_tokens = False
+
         return new_model_args
```
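Taken together with the `__call` hunk above, this gives `max_tokens = -1` sentinel semantics: `__parse_model_args` strips the sentinel and records the intent in `self._use_max_tokens`, and `__call` later substitutes the measured budget once the prompt size is known. A compressed sketch of that handshake, with hypothetical free functions standing in for the two methods:

```python
def parse_model_args(model_args: dict) -> tuple[dict, bool]:
    # Strip the -1 sentinel; the concrete value is resolved at call
    # time, once the prompt size (and hence the remaining budget)
    # is known.
    args = dict(model_args)
    use_max_tokens = args.get("max_tokens") == -1
    if use_max_tokens:
        del args["max_tokens"]
    return args, use_max_tokens


args, use_max = parse_model_args({"temperature": 0.2, "max_tokens": -1})
assert use_max and "max_tokens" not in args

args2, use_max2 = parse_model_args({"max_tokens": 512})
assert not use_max2 and args2["max_tokens"] == 512
```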
