From 33f898fd5365ba5bc898a74ab484d0ddc94731c2 Mon Sep 17 00:00:00 2001
From: Siddharth Narayanan
Date: Fri, 24 Jan 2025 14:55:24 -0800
Subject: [PATCH 1/4] make message input/output configurable

---
 ldp/nn/graph/llm_call_op.py | 175 ++++++++++++++++++++----------------
 1 file changed, 96 insertions(+), 79 deletions(-)

diff --git a/ldp/nn/graph/llm_call_op.py b/ldp/nn/graph/llm_call_op.py
index 88b727ef..1dae965f 100644
--- a/ldp/nn/graph/llm_call_op.py
+++ b/ldp/nn/graph/llm_call_op.py
@@ -1,12 +1,19 @@
 from __future__ import annotations

 import json
+from collections.abc import Callable
 from functools import partial
 from typing import Any, ClassVar

 import tree
-from aviary.core import MalformedMessageError, Message, Messages
-from aviary.tools import Tool, ToolCall, ToolRequestMessage
+from aviary.core import (
+    MalformedMessageError,
+    Message,
+    Messages,
+    ToolCall,
+    ToolRequestMessage,
+    Tools,
+)
 from transformers import LogitsProcessorList

 from ldp.graph.gradient_estimators import assign_constant_grads
@@ -24,6 +31,76 @@
 from ..lm_config import LMConfig  # noqa: TID252


+def default_prep_messages_for_tokenizer(msgs: Messages) -> list[dict]:
+    """A default messages prep function following the Llama 3.1 syntax."""
+    result: list[dict] = []
+    for msg in msgs:
+        content = msg.content
+        if isinstance(msg, ToolRequestMessage):
+            assert len(msg.tool_calls) == 1, (
+                "Support parsing only single tool call for now"
+            )
+            tool_call = msg.tool_calls[0]
+            # TODO: document where this format is coming from. Is this a Huggingface chat template syntax?
+            content_dict = {
+                "name": tool_call.function.name,
+                "parameters": tool_call.function.arguments,
+                "thought": msg.content,
+            }
+            content = json.dumps(content_dict)
+        assert content is not None, "content is None, doesn't make sense"
+
+        result.append({"role": msg.role, "content": content})
+    return result
+
+
+def default_prep_tools_for_tokenizer(tools: Tools | None) -> list[dict] | None:
+    """A default tools prep function following the Llama 3.1 syntax."""
+    if not tools:
+        return None
+
+    # TODO: should be able to switch to tool.info.model_dump() here
+    return [
+        {
+            "name": tool.info.name,
+            "description": tool.info.description,
+            "parameters": {
+                "type": tool.info.parameters.type,
+                "properties": {
+                    prop_name: {
+                        "type": prop_details.get("type"),
+                        "description": prop_details.get("description"),
+                        "title": prop_details.get("title"),
+                    }
+                    for prop_name, prop_details in tool.info.parameters.properties.items()
+                },
+                "required": tool.info.parameters.required,
+            },
+        }
+        for tool in tools
+    ]
+
+
+def default_parse_tool_request_message(
+    out_text: str, tools: Tools
+) -> ToolRequestMessage:
+    """A default tool request message parsing following the Llama 3.1 syntax."""
+    try:
+        tool_request = json.loads(out_text)
+        tool_name = tool_request["name"]
+        tool = next(t for t in tools if t.info.name == tool_name)
+        tool_thought = tool_request.get("thought", "")
+        tool_parameters = tool_request.get("parameters", {})
+        return ToolRequestMessage(
+            tool_calls=[ToolCall.from_tool(tool, **tool_parameters)],
+            content=tool_thought,
+        )
+    except StopIteration as exc:
+        raise MalformedMessageError(f"Tool {tool_name} not found in tools.") from exc
+    except json.JSONDecodeError as err:
+        raise ValueError(f"Failed to parse tools call message: {out_text}") from err
+
+
 class LocalLLMCallOp(Op[Message]):
     """An Op that samples a token sequence from a local language model."""

@@ -39,6 +116,15 @@ def __init__(
         batch_size: int = 1,
         max_wait_interval: float = 0.1,
         parallel_mode_config: ParallelModeConfig | None = None,
+        prep_messages_for_tokenizer: Callable[
+            [Messages], list[dict]
+        ] = default_prep_messages_for_tokenizer,
+        prep_tools_for_tokenizer: Callable[
+            [Tools | None], list[dict] | None
+        ] = default_prep_tools_for_tokenizer,
+        parse_tool_request_message: Callable[
+            [str, Tools], ToolRequestMessage
+        ] = default_parse_tool_request_message,
     ) -> None:
         super().__init__()

@@ -61,91 +147,22 @@ def __init__(
         self.model_handler = handler_config.make_async_module()
         self.model_name = model_config.model

-        self.llm_call_kwargs = {"logits_processor": LogitsProcessorList()}
+        self.prep_messages_for_tokenizer = prep_messages_for_tokenizer
+        self.prep_tools_for_tokenizer = prep_tools_for_tokenizer
+        self.parse_tool_request_message = parse_tool_request_message

-    @staticmethod
-    def prep_messages_for_tokenizer(xi: Messages) -> list[dict]:
-        result: list[dict] = []
-        for msg in xi:
-            content = msg.content
-            if isinstance(msg, ToolRequestMessage):
-                assert len(msg.tool_calls) == 1, (
-                    "Support parsing only single tool call for now"
-                )
-                tool_call = msg.tool_calls[0]
-                # TODO: document where this format is coming from. Is this a Huggingface chat template syntax?
-                content_dict = {
-                    "name": tool_call.function.name,
-                    "parameters": tool_call.function.arguments,
-                    "thought": msg.content,
-                }
-                content = json.dumps(content_dict)
-            assert content is not None, "content is None, doesn't make sense"
-
-            result.append({"role": msg.role, "content": content})
-        return result
-
-    @staticmethod
-    def prep_tools_for_tokenizer(tools: list[Tool] | None) -> list[dict] | None:
-        """Prepare tools for the tokenizer by transforming them into a JSON schema."""
-        if not tools:
-            return None
-
-        # TODO: should be able to switch to tool.info.model_dump() here
-        return [
-            {
-                "name": tool.info.name,
-                "description": tool.info.description,
-                "parameters": {
-                    "type": tool.info.parameters.type,
-                    "properties": {
-                        prop_name: {
-                            "type": prop_details.get("type"),
-                            "description": prop_details.get("description"),
-                            "title": prop_details.get("title"),
-                        }
-                        for prop_name, prop_details in tool.info.parameters.properties.items()
-                    },
-                    "required": tool.info.parameters.required,
-                },
-            }
-            for tool in tools
-        ]
-
-    @staticmethod
-    def _parse_tool_request(out_text: str, tools: list[Tool]) -> ToolRequestMessage:
-        """Parse the output text to extract the tool request.
-
-        TODO: see if this needs to be configurable, e.g. for different model
-        output formats that we want to experiment with.
-        """
-        try:
-            tool_request = json.loads(out_text)
-            tool_name = tool_request["name"]
-            tool = next(t for t in tools if t.info.name == tool_name)
-            tool_thought = tool_request.get("thought", "")
-            tool_parameters = tool_request.get("parameters", {})
-            return ToolRequestMessage(
-                tool_calls=[ToolCall.from_tool(tool, **tool_parameters)],
-                content=tool_thought,
-            )
-        except StopIteration as exc:
-            raise MalformedMessageError(
-                f"Tool {tool_name} not found in tools."
-            ) from exc
-        except json.JSONDecodeError as err:
-            raise ValueError(f"Failed to parse tools call message: {out_text}") from err
+        self.llm_call_kwargs = {"logits_processor": LogitsProcessorList()}

     async def forward(
         self,
-        xi: list[Message],
+        msgs: list[Message],
         temperature: float = 1.0,
         max_new_tokens: int = 10,
-        tools: list[Tool] | None = None,
+        tools: Tools | None = None,
         **kwargs: dict[str, Any],
     ) -> Message:
         call_id = get_call_id()
-        inputs = self.prep_messages_for_tokenizer(xi)
+        inputs = self.prep_messages_for_tokenizer(msgs)
         tools_json = self.prep_tools_for_tokenizer(tools)
         if get_training_mode():
             self.ctx.update(call_id, LocalLLMCallOp.CTX_INPUTS_PREP_KEY, inputs)
@@ -166,7 +183,7 @@ async def forward(

         out_msg = Message(role="assistant", content=out_text)
         if tools and out_text.startswith("{"):
-            out_msg = self._parse_tool_request(out_text, tools)
+            out_msg = self.parse_tool_request_message(out_text, tools)

         if get_training_mode():
             self.ctx.update(

From 572bba0f7750744597874cb8698929d14878f015 Mon Sep 17 00:00:00 2001
From: Siddharth Narayanan
Date: Fri, 24 Jan 2025 18:43:44 -0800
Subject: [PATCH 2/4] switch to a class

---
 ldp/nn/graph/llm_call_op.py            | 183 +++++++++++++++----------
 ldp/nn/handlers/transformer_handler.py |  10 +-
 2 files changed, 118 insertions(+), 75 deletions(-)

diff --git a/ldp/nn/graph/llm_call_op.py b/ldp/nn/graph/llm_call_op.py
index 1dae965f..03b38efd 100644
--- a/ldp/nn/graph/llm_call_op.py
+++ b/ldp/nn/graph/llm_call_op.py
@@ -1,7 +1,8 @@
 from __future__ import annotations

 import json
-from collections.abc import Callable
+import logging
+from abc import ABC, abstractmethod
 from functools import partial
 from typing import Any, ClassVar

 import tree
 from aviary.core import (
@@ -30,75 +31,120 @@
 )
 from ..lm_config import LMConfig  # noqa: TID252

+logger = logging.getLogger(__name__)

-def default_prep_messages_for_tokenizer(msgs: Messages) -> list[dict]:
-    """A default messages prep function following the Llama 3.1 syntax."""
-    result: list[dict] = []
-    for msg in msgs:
-        content = msg.content
+
+class MessageAndToolParser(ABC):
+    """Base class to define how we translate between (messages, tools) and strings."""
+
+    supported_templates: ClassVar[set[str]] = set()
+
+    @abstractmethod
+    @classmethod
+    def get_message_content(cls, msg: Message) -> str | None:
+        """Represents a message as a string."""
+
+    @abstractmethod
+    @classmethod
+    def prep_tools_for_tokenizer(cls, tools: Tools | None) -> list[dict] | None:
+        """Prepares tools for tokenization."""
+
+    @abstractmethod
+    @classmethod
+    def parse_tool_request_message(
+        cls, out_text: str, tools: Tools
+    ) -> ToolRequestMessage:
+        """Parses the output text from a tool request message."""
+
+    @classmethod
+    def prep_messages_for_tokenizer(cls, msgs: Messages) -> list[dict]:
+        """Prepares message history for tokenization."""
+        result: list[dict] = []
+        for msg in msgs:
+            content = cls.get_message_content(msg)
+            assert content is not None, f"Content should not be None: {msg!r}"
+            result.append({"role": msg.role, "content": content})
+        return result
+
+
+class Llama31Parser(MessageAndToolParser):
+    """Follows the Llama 3.1 syntax.
+
+    See details:
+    https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#-tool-calling-(8b/70b/405b)-
+    """
+
+    supported_templates: ClassVar[set[str]] = {
+        "llama2_chat_template_ori.jinja",
+        "llama3.1_chat_template_hf.jinja",
+        "llama3.1_chat_template_nothought.jinja",
+        "llama3.1_chat_template_thought.jinja",
+        "llama3.1_chat_template_vllm.jinja",
+        "llama3_chat_template_ori.jinja",
+    }
+
+    @classmethod
+    def get_message_content(cls, msg: Message) -> str | None:
         if isinstance(msg, ToolRequestMessage):
             assert len(msg.tool_calls) == 1, (
                 "Support parsing only single tool call for now"
             )
             tool_call = msg.tool_calls[0]
-            # TODO: document where this format is coming from. Is this a Huggingface chat template syntax?
             content_dict = {
                 "name": tool_call.function.name,
                 "parameters": tool_call.function.arguments,
                 "thought": msg.content,
             }
-            content = json.dumps(content_dict)
-        assert content is not None, "content is None, doesn't make sense"
-
-        result.append({"role": msg.role, "content": content})
-    return result
-
-
-def default_prep_tools_for_tokenizer(tools: Tools | None) -> list[dict] | None:
-    """A default tools prep function following the Llama 3.1 syntax."""
-    if not tools:
-        return None
-
-    # TODO: should be able to switch to tool.info.model_dump() here
-    return [
-        {
-            "name": tool.info.name,
-            "description": tool.info.description,
-            "parameters": {
-                "type": tool.info.parameters.type,
-                "properties": {
-                    prop_name: {
-                        "type": prop_details.get("type"),
-                        "description": prop_details.get("description"),
-                        "title": prop_details.get("title"),
-                    }
-                    for prop_name, prop_details in tool.info.parameters.properties.items()
-                },
-                "required": tool.info.parameters.required,
-            },
-        }
-        for tool in tools
-    ]
-
-
-def default_parse_tool_request_message(
-    out_text: str, tools: Tools
-) -> ToolRequestMessage:
-    """A default tool request message parsing following the Llama 3.1 syntax."""
-    try:
-        tool_request = json.loads(out_text)
-        tool_name = tool_request["name"]
-        tool = next(t for t in tools if t.info.name == tool_name)
-        tool_thought = tool_request.get("thought", "")
-        tool_parameters = tool_request.get("parameters", {})
-        return ToolRequestMessage(
-            tool_calls=[ToolCall.from_tool(tool, **tool_parameters)],
-            content=tool_thought,
-        )
-    except StopIteration as exc:
-        raise MalformedMessageError(f"Tool {tool_name} not found in tools.") from exc
-    except json.JSONDecodeError as err:
-        raise ValueError(f"Failed to parse tools call message: {out_text}") from err
+            return json.dumps(content_dict)
+
+        return msg.content
+
+    @classmethod
+    def prep_tools_for_tokenizer(cls, tools: Tools | None) -> list[dict] | None:
+        if not tools:
+            return None
+
+        # TODO: should be able to switch to tool.info.model_dump() here
+        return [
+            {
+                "name": tool.info.name,
+                "description": tool.info.description,
+                "parameters": {
+                    "type": tool.info.parameters.type,
+                    "properties": {
+                        prop_name: {
+                            "type": prop_details.get("type"),
+                            "description": prop_details.get("description"),
+                            "title": prop_details.get("title"),
+                        }
+                        for prop_name, prop_details in tool.info.parameters.properties.items()
+                    },
+                    "required": tool.info.parameters.required,
+                },
+            }
+            for tool in tools
+        ]
+
+    @classmethod
+    def parse_tool_request_message(
+        cls, out_text: str, tools: Tools
+    ) -> ToolRequestMessage:
+        try:
+            tool_request = json.loads(out_text)
+            tool_name = tool_request["name"]
+            tool = next(t for t in tools if t.info.name == tool_name)
+            tool_thought = tool_request.get("thought", "")
+            tool_parameters = tool_request.get("parameters", {})
+            return ToolRequestMessage(
+                tool_calls=[ToolCall.from_tool(tool, **tool_parameters)],
+                content=tool_thought,
+            )
+        except StopIteration as exc:
+            raise MalformedMessageError(
+                f"Tool {tool_name} not found in tools."
+            ) from exc
+        except json.JSONDecodeError as err:
+            raise ValueError(f"Failed to parse tools call message: {out_text}") from err


 class LocalLLMCallOp(Op[Message]):
@@ -116,15 +162,7 @@ def __init__(
         batch_size: int = 1,
         max_wait_interval: float = 0.1,
         parallel_mode_config: ParallelModeConfig | None = None,
-        prep_messages_for_tokenizer: Callable[
-            [Messages], list[dict]
-        ] = default_prep_messages_for_tokenizer,
-        prep_tools_for_tokenizer: Callable[
-            [Tools | None], list[dict] | None
-        ] = default_prep_tools_for_tokenizer,
-        parse_tool_request_message: Callable[
-            [str, Tools], ToolRequestMessage
-        ] = default_parse_tool_request_message,
+        parser: type[MessageAndToolParser] = Llama31Parser,
     ) -> None:
         super().__init__()

@@ -147,9 +185,14 @@ def __init__(
         self.model_handler = handler_config.make_async_module()
         self.model_name = model_config.model

-        self.prep_messages_for_tokenizer = prep_messages_for_tokenizer
-        self.prep_tools_for_tokenizer = prep_tools_for_tokenizer
-        self.parse_tool_request_message = parse_tool_request_message
+        self.prep_messages_for_tokenizer = parser.prep_messages_for_tokenizer
+        self.prep_tools_for_tokenizer = parser.prep_tools_for_tokenizer
+        self.parse_tool_request_message = parser.parse_tool_request_message
+        if model_config.chat_template not in parser.supported_templates:
+            logger.warning(
+                f"Chat template {model_config.chat_template!r} not in "
+                f"{parser.__class__.__name__}.supported templates."
+            )

         self.llm_call_kwargs = {"logits_processor": LogitsProcessorList()}

diff --git a/ldp/nn/handlers/transformer_handler.py b/ldp/nn/handlers/transformer_handler.py
index 6fd7b740..8561b355 100644
--- a/ldp/nn/handlers/transformer_handler.py
+++ b/ldp/nn/handlers/transformer_handler.py
@@ -66,15 +66,15 @@
 TParams = ParamSpec("TParams")


-def is_conversation(messages) -> bool:
-    """Check if messages is an instance of Conversation."""
-    return isinstance(messages, list) and all(
+def is_message_history(maybe_messages) -> bool:
+    """Check if input is a message history encoded as list of dict[str, str]."""
+    return isinstance(maybe_messages, list) and all(
         isinstance(msg, dict)
         and all(
             isinstance(key, str) and isinstance(value, str)
             for key, value in msg.items()
         )
-        for msg in messages
+        for msg in maybe_messages
     )

@@ -894,7 +894,7 @@ def _get_tokenized_inputs(
         return BatchEncoding(inputs)
     if isinstance(inputs, str):
         return tokenizer(inputs, return_tensors="pt")
-    if is_conversation(inputs):
+    if is_message_history(inputs):
         return tokenizer.apply_chat_template(
             inputs,
             tools=tools_json,

From 1e7382d6a62812bf24a50e8926f8f586280c6f83 Mon Sep 17 00:00:00 2001
From: Siddharth Narayanan
Date: Fri, 24 Jan 2025 18:49:21 -0800
Subject: [PATCH 3/4] fix order

---
 ldp/nn/graph/llm_call_op.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/ldp/nn/graph/llm_call_op.py b/ldp/nn/graph/llm_call_op.py
index 03b38efd..4d1c36ec 100644
--- a/ldp/nn/graph/llm_call_op.py
+++ b/ldp/nn/graph/llm_call_op.py
@@ -39,18 +39,18 @@ class MessageAndToolParser(ABC):

     supported_templates: ClassVar[set[str]] = set()

-    @abstractmethod
     @classmethod
+    @abstractmethod
     def get_message_content(cls, msg: Message) -> str | None:
         """Represents a message as a string."""

-    @abstractmethod
     @classmethod
+    @abstractmethod
     def prep_tools_for_tokenizer(cls, tools: Tools | None) -> list[dict] | None:
         """Prepares tools for tokenization."""

-    @abstractmethod
     @classmethod
+    @abstractmethod
     def parse_tool_request_message(
         cls, out_text: str, tools: Tools
     ) -> ToolRequestMessage:

From 402fd891ff4bb087d4fc58effa543d3f01672b2f Mon Sep 17 00:00:00 2001
From: Siddharth Narayanan
Date: Fri, 24 Jan 2025 19:06:45 -0800
Subject: [PATCH 4/4] remove kwargs typing

---
 ldp/nn/graph/llm_call_op.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ldp/nn/graph/llm_call_op.py b/ldp/nn/graph/llm_call_op.py
index 4d1c36ec..9a5426e5 100644
--- a/ldp/nn/graph/llm_call_op.py
+++ b/ldp/nn/graph/llm_call_op.py
@@ -4,7 +4,7 @@
 import logging
 from abc import ABC, abstractmethod
 from functools import partial
-from typing import Any, ClassVar
+from typing import ClassVar

 import tree
 from aviary.core import (
@@ -202,7 +202,7 @@ async def forward(
         temperature: float = 1.0,
         max_new_tokens: int = 10,
         tools: Tools | None = None,
-        **kwargs: dict[str, Any],
+        **kwargs,
     ) -> Message:
         call_id = get_call_id()
         inputs = self.prep_messages_for_tokenizer(msgs)
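
After this series, swapping the tool-calling syntax means subclassing MessageAndToolParser and handing the class to LocalLLMCallOp via the new parser argument, rather than wiring up three separate callables. The sketch below is illustrative only: the PlainJsonParser name and its template name are hypothetical, and it simply reuses the JSON request shape that the Llama 3.1 default in this series already parses.

    import json

    from aviary.core import Message, ToolCall, ToolRequestMessage, Tools

    from ldp.nn.graph.llm_call_op import MessageAndToolParser


    class PlainJsonParser(MessageAndToolParser):
        """Hypothetical parser for a non-Llama chat template (illustrative sketch)."""

        supported_templates = {"my_model_chat_template.jinja"}  # hypothetical template name

        @classmethod
        def get_message_content(cls, msg: Message) -> str | None:
            # Render tool requests however the target template expects; plain passthrough here.
            return msg.content

        @classmethod
        def prep_tools_for_tokenizer(cls, tools: Tools | None) -> list[dict] | None:
            # The TODO in this series suggests tool.info.model_dump() as a simpler schema dump.
            return [tool.info.model_dump() for tool in tools] if tools else None

        @classmethod
        def parse_tool_request_message(cls, out_text: str, tools: Tools) -> ToolRequestMessage:
            # Same shape as the Llama 3.1 default: {"name": ..., "parameters": {...}, "thought": ...}.
            request = json.loads(out_text)
            tool = next(t for t in tools if t.info.name == request["name"])
            return ToolRequestMessage(
                tool_calls=[ToolCall.from_tool(tool, **request.get("parameters", {}))],
                content=request.get("thought", ""),
            )

Such a class would then be passed as LocalLLMCallOp(parser=PlainJsonParser, ...) with the other constructor arguments unchanged by this series; the warning added in patch 2 fires when the configured chat template is not listed in the parser's supported_templates.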