Skip to content

Commit b3d409f

Browse files
Add TODO's and WIP config classes
1 parent 997cc7b commit b3d409f

File tree

1 file changed

+96
-4
lines changed

1 file changed

+96
-4
lines changed

src/agentlab/llm/response_api.py

Lines changed: 96 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949

5050
@dataclass
5151
class ToolCall:
52+
#TODO: Check if this is a suitable tool representation for being MCP compliant.
5253
name: str = field(default=None)
5354
arguments: Dict[str, Any] = field(default_factory=dict)
5455
raw_call: Any = field(default=None)
@@ -57,6 +58,8 @@ class ToolCall:
5758
@property
5859
def is_env_action(self) -> bool:
5960
"""Check if the tool call is a reserved BGYM action."""
61+
# TODO: env should return some func to check if agent action is env action.
62+
# Keep in mind env may or may not have a fixed set of reserved actions.
6063
return self.name in BGYM_RESERVED_ACTION_FUNCTION_NAMES
6164

6265
@property
@@ -86,11 +89,11 @@ def add_tool_call(self, tool_call: ToolCall) -> "ToolCalls":
8689
return self
8790

8891
def get_env_action_calls(self) -> List[ToolCall]:
89-
"""Get all tool calls that are reserved BGYM actions."""
92+
"""Get all tool calls that are reserved Environment actions."""
9093
return [call for call in self.tool_calls if call.is_env_action]
9194

9295
def get_non_env_action_calls(self) -> List[ToolCall]:
93-
"""Get all tool calls that are not reserved BGYM actions."""
96+
"""Get all tool calls that are not reserved Environment actions."""
9497
return [call for call in self.tool_calls if not call.is_env_action]
9598

9699
@property
@@ -125,7 +128,7 @@ class MessageBuilder:
125128
def __init__(self, role: str):
126129

127130
self.role = role
128-
self.last_raw_response: LLMOutput = None # NOTE: last_raw_response will be deprecated in future version.
131+
self.last_raw_response: LLMOutput = None # NOTE: last_raw_response will be deprecated in future version. We can use ToolCalls object to get all the relevant information.
129132
self.content: List[ContentItem] = []
130133
self.responsed_tool_calls: ToolCalls = None
131134

@@ -410,6 +413,85 @@ class BaseModelWithPricing(TrackAPIPricingMixin, BaseResponseModel):
410413
pass
411414

412415

416+
417+
# TODO: Define and use Flexible set of Configuration.
# Below configs are not used and are WIP.
# _______________________________________________________________

# Some High-level requirements.

# Env can have multiple action sets. Each action set should be supported as tools and prompt description.
# Env should have conversion functions to parse the tool calls or text back to env actions.

# Backend LLMs or Large action models can have their own action sets (UI-TARS, CUA), which can be fixed or flexible.
# EnvConfig or LLMConfig or ActionConfig should provide conversion from Backend LLM action to env_action.

# AgentLab Agents may emit multiple actions. EnvConfig should mention if it supports multiple actions in a single step.
# If Env controller does not natively support multi-actions, we can choose to integrate Env logic which brings this support.

# Env should broadcast what observations are supported and the agent loop should be able to handle them (e.g., AXTree).


@dataclass
class GenerationConfig:
    """Sampling parameters for a single LLM generation call.

    Defined before the configs that embed it so the eagerly-evaluated
    class-body annotations below can resolve it.
    """

    temperature: float = 0.5
    max_new_tokens: int = 100
    # Might be useful for exploration to have the ability to modify these inside the agent loop.


@dataclass
class ActionConfig:
    """Describes how an environment's action set is exposed to the agent/LLM."""

    # TODO: AgentLab AbstractActionSet has constructor methods to create actions
    # as tools or as text descriptions with examples.
    action_set: "AbstractActionSet"
    multiaction: bool = True
    env_action_as_tools: bool = True  # If True, action set is treated as tools
    tools: Optional[List[Dict[str, Any]]] = None  # List of tool definitions or list of functions
    tool_text_descriptions: str = ""  # Some description of the tools, emitted by the environment.
    # Some callable given by the environment to convert tool calls to env actions.
    # NOTE: original line was a syntax error (annotation with no default value).
    tools_calls_to_env_action_parser: Optional["Callable"] = None
    text_to_env_action_parser: Optional["Type[MessageBuilder]"] = None


@dataclass
class ObsConfig:
    """Configuration of the observations the agent consumes. WIP."""

    # Check generic agent
    pass


@dataclass
class PromptConfig:
    """Prompt-level feature switches. WIP."""

    # use_hints
    # use_summarizer
    pass


@dataclass
class ProviderConfig:
    """Configuration for the LLM provider."""

    api_key_env_var: Optional[str] = None
    base_url: Optional[str] = None  # Base URL for the API, if different
    # Anything else? vLLM-specific configurations?, etc.


@dataclass
class LLMConfig:
    """Backend-LLM-specific configuration. WIP."""

    # backend LLM supported action set
    # Any other LLM specific configurations
    # Tool calling format?
    # Maybe include provider specific configurations here?
    pass


@dataclass
class Config:
    """Top-level agent configuration bundling model, observation and action configs."""

    model_args: "BaseModelArgs"
    obs: ObsConfig
    action: ActionConfig
    generationConfig: GenerationConfig


@dataclass
class APIPayload:
    """Everything the agent loop needs to form one LLM API call.

    The agent loop will form the payload based on the config and pass it to the API call.
    """

    messages: "List[MessageBuilder | ToolCalls]"
    api_endpoint: str
    api_key_env_var: Optional[str] = None
    base_url: Optional[str] = None
    tools: Optional[List[Dict[str, Any]]] = None  # Taken from ActionConfig
    # Fix some literal value for tool choice, e.g., "auto", and convert according to the API.
    # OpenAI and Anthropic can have different tool-choice parameters that behave differently.
    tool_choice: Optional[str] = None
    # default_factory: a bare `GenerationConfig()` default would share one mutable
    # instance across payloads, and raises ValueError under dataclasses on Python >= 3.11.
    generation_config: GenerationConfig = field(default_factory=GenerationConfig)
    caching: bool = False  # If True, cache the response


# _______________________________________________________________
494+
413495
class OpenAIResponseModel(BaseModelWithPricing):
414496
def __init__(
415497
self,
@@ -437,6 +519,7 @@ def __init__(
437519
def _call_api(self, messages: list[Any | MessageBuilder], **kwargs) -> dict:
438520
input = self.convert_messages_to_api_format(messages)
439521

522+
#TODO: API/Payload Params should be a config dataclass. Update once settled on a config structure.
440523
api_params: Dict[str, Any] = {
441524
"model": self.model_name,
442525
"input": input,
@@ -486,8 +569,10 @@ def _parse_response(self, response: "OpenAIResponseObject") -> LLMOutput:
486569
tool_calls=toolcalls if toolcalls is not None else None,
487570
)
488571

572+
489573
def _extract_tool_calls_from_response(self, response: "OpenAIResponseObject") -> ToolCalls:
490574
"""Extracts tool calls from the response."""
575+
# TODO: Should this be in the BaseResponseModel class?
491576
tool_calls = ToolCalls(raw_calls=response.output)
492577
for output in response.output:
493578
if output.type == "function_call":
@@ -507,6 +592,7 @@ def _extract_tool_calls_from_response(self, response: "OpenAIResponseObject") ->
507592

508593
def _extract_env_actions_from_toolcalls(self, toolcalls: ToolCalls) -> Any | None:
509594
"""Extracts actions from the response."""
595+
# TODO: Should this be in the BaseResponseModel class? Or emitted by the Environment?
510596
actions = []
511597
for call in toolcalls:
512598
if call.is_env_action:
@@ -530,8 +616,9 @@ def _extract_thinking_content_from_response(self, response: "OpenAIResponseObjec
530616
thinking_content += f"{output.output_text}\n"
531617
return thinking_content
532618

533-
# Environment Specific functions, in this case BGYM
619+
### Environment Specific functions, in this case BGYM ###
534620

621+
# TODO: Should the below functions be in the BaseResponseModel class? Or emitted by the Environment and initialized using a config?
535622
def convert_toolcall_to_env_action_format(self, toolcall: ToolCall) -> str:
536623
"""Convert a tool call to an BGYM environment action string."""
537624
action_name, tool_args = toolcall.name, toolcall.arguments
@@ -554,6 +641,7 @@ def _extract_env_actions_from_text_response(self, response: "OpenAIResponseObjec
554641
pass
555642

556643

644+
# TODO: Refactor similar to OpenAIResponseModel
557645
class OpenAIChatCompletionModel(BaseModelWithPricing):
558646
def __init__(
559647
self,
@@ -684,6 +772,8 @@ def extract_content_with_reasoning(message, wrap_tag="think"):
684772
reasoning_content = ""
685773
return f"{reasoning_content}{msg_content}{message.get('content', '')}"
686774

775+
776+
# TODO: Refactor similar to OpenAIResponseModel
687777
class ClaudeResponseModel(BaseModelWithPricing):
688778
def __init__(
689779
self,
@@ -807,6 +897,8 @@ def apply_cache_breakpoints(self, msg: Message, prepared_msg: dict) -> List[Mess
807897

808898

809899
# Factory classes to create the appropriate model based on the API endpoint.
900+
901+
# TODO: Do we really need these factory classes? how about implementing a _from_args() method in the BaseModelArgs class?
810902
@dataclass
811903
class OpenAIResponseModelArgs(BaseModelArgs):
812904
"""Serializable object for instantiating a generic chat model with an OpenAI

0 commit comments

Comments
 (0)