Skip to content

Commit a6f5349

Browse files
improve type hints
1 parent 22f9385 commit a6f5349

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

src/agentlab/llm/response_api.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,10 @@ class LLMOutput:
104104

105105
raw_response: Any = field(default=None)
106106
think: str = field(default="")
107-
action: str = field(default=None) # Default action if no tool call is made
108-
tool_calls: ToolCalls = field(default=None) # This will hold the tool call response if any
107+
action: str | None = field(default=None) # Default action if no tool call is made
108+
tool_calls: ToolCalls | None = field(
109+
default=None
110+
) # This will hold the tool call response if any
109111

110112

111113
class MessageBuilder:
@@ -374,10 +376,10 @@ def mark_all_previous_msg_for_caching(self):
374376

375377
@dataclass
376378
class APIPayload:
377-
messages: List[MessageBuilder | ToolCalls] = None
379+
messages: List[MessageBuilder] | None = None
378380
tools: List[Dict[str, Any]] | None = None
379381
tool_choice: Literal["none", "auto", "any", "required"] | None = None
380-
force_call_tool: str = (
382+
force_call_tool: str | None = (
381383
None # Name of the tool to call # If set, will force the LLM to call this tool.
382384
)
383385
use_cache_breakpoints: bool = (
@@ -410,7 +412,7 @@ def __init__(
410412
self.max_tokens = max_tokens
411413
super().__init__()
412414

413-
def __call__(self, payload: APIPayload) -> dict:
415+
def __call__(self, payload: APIPayload) -> LLMOutput:
414416
"""Make a call to the model and return the parsed response."""
415417
response = self._call_api(payload)
416418
return self._parse_response(response)
@@ -431,25 +433,29 @@ class AgentlabAction:
431433
Collection of utility function to convert tool calls to Agentlab action format.
432434
"""
433435

436+
@staticmethod
434437
def convert_toolcall_to_agentlab_action_format(toolcall: ToolCall) -> str:
435438
"""Convert a tool call to an Agentlab environment action string.
439+
436440
Args:
437441
toolcall: ToolCall object containing the name and arguments of the tool call.
438442
439443
Returns:
440-
str: A string representing the action in Agentlab format i.e. python function call string.
444+
A string representing the action in Agentlab format i.e. python function call string.
441445
"""
442446

443447
tool_name, tool_args = toolcall.name, toolcall.arguments
444448
return tool_call_to_python_code(tool_name, tool_args)
445449

446-
def convert_multiactions_to_agentlab_action_format(actions: list[str]) -> str:
450+
@staticmethod
451+
def convert_multiactions_to_agentlab_action_format(actions: list[str]) -> str | None:
447452
"""Convert multiple actions list to a format that env supports.
453+
448454
Args:
449455
actions: List of action strings to be joined.
450456
451457
Returns:
452-
str: Joined actions separated by newlines, or None if empty.
458+
Joined actions separated by newlines, or None if empty.
453459
"""
454460
return "\n".join(actions) if actions else None
455461

0 commit comments

Comments
 (0)