@@ -26,43 +26,7 @@ class _InnerCallLLMResponse:
 
 
 class CallLLM(Step, input_class=CallLLMInputs, output_class=CallLLMOutputs):
-    def __init__(self, inputs: dict):
-        super().__init__(inputs)
-        # Set 'openai_key' from inputs or environment if not already set
-        inputs.setdefault("openai_api_key", os.environ.get("OPENAI_API_KEY"))
-
-        prompt_file = inputs.get("prompt_file")
-        if prompt_file is not None:
-            prompt_file_path = Path(prompt_file)
-            if not prompt_file_path.is_file():
-                raise ValueError(f'Unable to find Prompt file: "{prompt_file}"')
-            try:
-                with open(prompt_file_path, "r") as fp:
-                    self.prompts = json.load(fp)
-            except json.JSONDecodeError as e:
-                raise ValueError(f'Invalid Json Prompt file "{prompt_file}": {e}')
-        elif "prompts" in inputs.keys():
-            self.prompts = inputs["prompts"]
-        else:
-            raise ValueError('Missing required data: "prompt_file" or "prompts"')
-
-        self.call_limit = int(inputs.get("max_llm_calls", -1))
-        self.model_args = {key[len("model_") :]: value for key, value in inputs.items() if key.startswith("model_")}
-        self.save_responses_to_file = inputs.get("save_responses_to_file", None)
-        self.model = inputs.get("model", "gpt-4o-mini")
-        self.allow_truncated = inputs.get("allow_truncated", False)
-        self.file = inputs.get("file", None)
-        self.client = AioLlmClient.create_aio_client(inputs)
-        if self.client is None:
-            raise ValueError(
-                f"Model API key not found.\n"
-                f'Please login at: "{TOKEN_URL}",\n'
-                "Please go to the Integration's tab and generate an API key.\n"
-                "Please copy the access token that is generated, "
-                "and add `--patched_api_key=<token>` to the command line.\n"
-                "\n"
-                "If you are using an OpenAI API Key, please set `--openai_api_key=<token>`.\n"
-            )
+
 
     def __persist_to_file(self, contents):
         # Convert relative path to absolute path
@@ -121,10 +85,22 @@ def __call(self, prompts: list[list[dict]]) -> list[_InnerCallLLMResponse]:
             kwargs["file"] = Path(self.file)
 
         for prompt in prompts:
-            is_input_accepted = self.client.is_prompt_supported(model=self.model, messages=prompt, **kwargs) > 0
+            available_tokens = self.client.is_prompt_supported(model=self.model, messages=prompt, **kwargs)
+            is_input_accepted = available_tokens > 0
+
             if not is_input_accepted:
                 self.set_status(StepStatus.WARNING, "Input token limit exceeded.")
                 prompt = self.client.truncate_messages(prompt, self.model)
+
+            # Handle the case where model_max_tokens was set to -1:
+            # derive max_tokens from the tokens left in the model context after the prompt
+            if hasattr(self, "_use_max_tokens") and self._use_max_tokens:
+                if available_tokens > 0:
+                    kwargs["max_tokens"] = available_tokens
+                    logger.info(f"Setting max_tokens to {available_tokens} based on available model context")
+                else:
+                    # The available token count is unknown, so fall back to the model's default
+                    logger.warning("Could not determine available tokens. Using model default.")
 
             logger.trace(f"Message sent: \n{escape(indent(pformat(prompt), ' '))}")
             try:
@@ -184,4 +160,13 @@ def __parse_model_args(self) -> dict:
             else:
                 new_model_args[key] = arg
 
+        # Handle special case for max_tokens = -1 (use maximum available tokens)
+        if "max_tokens" in new_model_args and new_model_args["max_tokens"] == -1:
+            # Will be handled during the chat completion call
+            logger.info("Using maximum available tokens for the model")
+            del new_model_args["max_tokens"]  # Remove it for now, we'll calculate it later
+            self._use_max_tokens = True
+        else:
+            self._use_max_tokens = False
+
         return new_model_args
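
Taken together, the two hunks implement a max_tokens = -1 convention: __parse_model_args strips the sentinel and records it in self._use_max_tokens, and __call later fills kwargs["max_tokens"] with whatever context the model has left after the prompt, as reported by is_prompt_supported. The snippet below is a minimal standalone sketch of that flow; FakeClient, parse_model_args, and resolve_kwargs are hypothetical stand-ins for AioLlmClient and the private CallLLM methods, not names from this patch.

# Minimal sketch of the max_tokens=-1 flow, under the assumptions stated above.
class FakeClient:
    CONTEXT_WINDOW = 8192

    def is_prompt_supported(self, model: str, messages: list[dict], **kwargs) -> int:
        # Roughly 4 characters per token; returns the tokens left for the completion.
        used = sum(len(m["content"]) // 4 for m in messages)
        return self.CONTEXT_WINDOW - used


def parse_model_args(model_args: dict) -> tuple[dict, bool]:
    # Mirrors __parse_model_args: strip the -1 sentinel and remember that it was set.
    args = dict(model_args)
    use_max_tokens = args.get("max_tokens") == -1
    if use_max_tokens:
        del args["max_tokens"]
    return args, use_max_tokens


def resolve_kwargs(client, model: str, prompt: list[dict], kwargs: dict, use_max_tokens: bool) -> dict:
    # Mirrors the __call change: spend the remaining context window on the completion.
    available_tokens = client.is_prompt_supported(model=model, messages=prompt, **kwargs)
    if use_max_tokens and available_tokens > 0:
        kwargs = {**kwargs, "max_tokens": available_tokens}
    return kwargs


kwargs, use_max = parse_model_args({"temperature": 0.2, "max_tokens": -1})
prompt = [{"role": "user", "content": "Summarize this repository."}]
print(resolve_kwargs(FakeClient(), "gpt-4o-mini", prompt, kwargs, use_max))
# prints {'temperature': 0.2, 'max_tokens': <remaining context tokens>}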