From 13de0f42fb997fada4816494212c9a09a3bf5af7 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Thu, 16 Oct 2025 15:26:06 -0700 Subject: [PATCH] Remove chat environment --- src/forge/envs/__init__.py | 5 - src/forge/envs/chat.py | 212 ----------- tests/unit_tests/rl/environments/__init__.py | 5 - tests/unit_tests/rl/environments/test_chat.py | 331 ------------------ 4 files changed, 553 deletions(-) delete mode 100644 src/forge/envs/__init__.py delete mode 100644 src/forge/envs/chat.py delete mode 100644 tests/unit_tests/rl/environments/__init__.py delete mode 100644 tests/unit_tests/rl/environments/test_chat.py diff --git a/src/forge/envs/__init__.py b/src/forge/envs/__init__.py deleted file mode 100644 index 2e41cd717..000000000 --- a/src/forge/envs/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. diff --git a/src/forge/envs/chat.py b/src/forge/envs/chat.py deleted file mode 100644 index 24a5981a6..000000000 --- a/src/forge/envs/chat.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from dataclasses import dataclass, field - -import torch - -from forge.interfaces import Environment, Message, ModelTokenizer, Transform - -from forge.types import Action, Observation, State - - -@dataclass -class ChatAction(Action): - """Action for chat environments. - - Contains tokens that represent the action to be taken. - This interfaces directly with models. - """ - - tokens: torch.Tensor = field(default_factory=lambda: torch.tensor([])) - - def __post_init__(self): - """Validate required fields after initialization.""" - if self.tokens.numel() == 0: - raise ValueError("tokens is required and cannot be empty") - - -@dataclass -class ChatState(State): - """State of the ChatEnvironment containing message history.""" - - history_messages: list[Message] = field(default_factory=list) - history_tokens: list[torch.Tensor] = field( - default_factory=list - ) # Same len as messages - - -@dataclass -class ChatObservation(Observation): - """Observation returned by ChatEnvironment. - - Contains the message history in Huggingface format (list of dicts with role/content) - and the tokenized representation of the entire conversation. - - The environment owns the tokenizer and generates the tokens from the messages. - - Example: - messages = [ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "How tall is the Eiffel Tower?"}, - ] - tokens = tensor([1, 2, 3, 4, 5, ...]) # tokenized entire conversation - """ - - messages: list[Message] = field(default_factory=list) - tokens: torch.Tensor = field(default_factory=lambda: torch.tensor([])) - # Inherited fields from Observation ABC: reward, done, metadata - - -class ChatEnvironment(Environment): - """A chat-based environment for LLMs, designed as a blank canvas for conversation and RL. - - This environment is designed to work with language models. It provides the fundamental structure - for managing conversation state but is intentionally minimal to allow maximum flexibility. - - The environment owns the tokenizer and is responsible for managing both message history and tokens. - Actions contain only tokens that interface directly with models. - - Args: - tokenizer: A tokenizer that will be used to tokenize the conversation - system_prompt: An optional system prompt string to use during reset calls (optional) - system_role: The role of the system (at reset time). Defaults to "system" - """ - - def __init__( - self, - tokenizer: ModelTokenizer, - system_prompt: str | None = None, - system_role: str = "system", - transform: Transform | None = None, - ): - super().__init__(transform=transform) - - if not hasattr(tokenizer, "apply_chat_template"): - raise ValueError("Tokenizer must have 'apply_chat_template' method") - self.tokenizer = tokenizer - self.system_prompt = system_prompt - self.system_role = system_role - - self._state = ChatState() - - if system_prompt: - system_message: Message = {"role": system_role, "content": system_prompt} - self._state.history_messages.append(system_message) - # Tokenize the system message - system_tokens = self.tokenizer.apply_chat_template( - conversation=[system_message], tokenize=True, return_tensors="pt" # type: ignore - ) - self._state.history_tokens.append(system_tokens) - - def reset(self) -> ChatObservation: - """Reset the environment to initial state. - - Returns: - ChatObservation: Initial observation with system prompt (if any) - """ - self._state.history_messages = [] - self._state.history_tokens = [] - if self.system_prompt: - system_message: Message = { - "role": self.system_role, - "content": self.system_prompt, - } - self._state.history_messages = [system_message] - # Tokenize the system message - system_tokens = self.tokenizer.apply_chat_template( - conversation=[system_message], tokenize=True, return_tensors="pt" # type: ignore - ) - self._state.history_tokens = [system_tokens] - - return self._create_observation() - - def step(self, action: ChatAction) -> ChatObservation: - """Take a step in the environment by adding tokens to the chat history. - - Args: - action: A ChatAction object containing tokens. - - Returns: - ChatObservation: The updated observation with the new tokens added. - """ - # Store the tokens directly from the action - self._state.history_tokens.append(action.tokens) - - # Decode tokens to text and add as a message to history - decoded_text = self.tokenizer.decode( - action.tokens.squeeze(), skip_special_tokens=True - ) - assistant_message: Message = {"role": "assistant", "content": decoded_text} - self._state.history_messages.append(assistant_message) - - return self._create_observation() - - def _create_observation(self) -> ChatObservation: - """Create a ChatObservation from the current state. - - Returns both the message history and the tokens flattened as a single tensor - ready to be used by models. - - Returns: - ChatObservation: Observation with messages and flattened tokens - """ - if self._state.history_tokens: - flattened_tokens = torch.cat(self._state.history_tokens, dim=0) - else: - flattened_tokens = torch.tensor([]) - - observation = ChatObservation( - messages=self._state.history_messages.copy(), # Copy to prevent external mutation - tokens=flattened_tokens, - ) - - transformed = self._apply_transform(observation) - if isinstance(transformed, ChatObservation): - return transformed - else: - # If transform returns base Observation, convert back to ChatObservation - return ChatObservation( - messages=getattr(transformed, "messages", []), - tokens=getattr(transformed, "tokens", torch.tensor([])), - done=transformed.done, - reward=transformed.reward, - ) - - @property - def state(self) -> ChatState: - """Get the current state of the environment. - - Returns: - ChatState: The current state. - """ - return self._state - - def message_to_action(self, message: Message) -> ChatAction: - """Convert a message dictionary to a ChatAction with tokens. - - Args: - message: Dictionary with 'role' and 'content' keys - - Returns: - ChatAction: A new ChatAction instance with tokenized content - - Raises: - ValueError: If required keys are missing - """ - if "role" not in message: - raise ValueError("Message must contain a 'role' key") - if "content" not in message: - raise ValueError("Message must contain a 'content' key") - if message["content"] is None: - raise ValueError("Message content cannot be None") - - # Tokenize the single message - tokens = self.tokenizer.apply_chat_template( - conversation=[message], tokenize=True, return_tensors="pt" # type: ignore - ) - - return ChatAction(tokens=tokens) diff --git a/tests/unit_tests/rl/environments/__init__.py b/tests/unit_tests/rl/environments/__init__.py deleted file mode 100644 index 2e41cd717..000000000 --- a/tests/unit_tests/rl/environments/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. diff --git a/tests/unit_tests/rl/environments/test_chat.py b/tests/unit_tests/rl/environments/test_chat.py deleted file mode 100644 index 4abf89dc6..000000000 --- a/tests/unit_tests/rl/environments/test_chat.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import unittest -from typing import Any, Optional -from unittest.mock import MagicMock - -import torch - -from forge.envs.chat import ( - ChatAction, - ChatEnvironment, - ChatObservation, - ChatState, - Message, -) - - -class MockTokenizer: - """Mock tokenizer implementing TokenizerProtocol for testing.""" - - def apply_chat_template( - self, - conversation: list[dict[str, str]], - tools: Optional[list[dict]] = None, - documents: Optional[list[dict[str, str]]] = None, - chat_template: Optional[str] = None, - add_generation_prompt: bool = False, - continue_final_message: bool = False, - tokenize: bool = True, - padding: bool = False, - truncation: bool = False, - max_length: Optional[int] = None, - return_tensors: Optional[str] = None, - return_dict: bool = False, - return_assistant_tokens_mask: bool = False, - tokenizer_kwargs: Optional[dict[str, Any]] = None, - **kwargs, - ) -> torch.Tensor: - """Mock implementation of apply_chat_template.""" - # For testing, we'll just return a tensor with a simple pattern based on the conversation - # Each message contributes 10 tokens to the output - return torch.tensor([[i for i in range(len(conversation) * 10)]]) - - def decode( - self, - token_ids: Any, - skip_special_tokens: bool = False, - clean_up_tokenization_spaces: Optional[bool] = None, - **kwargs, - ) -> str: - """Mock implementation of decode.""" - # For testing, we'll just convert the tensor to a string - if isinstance(token_ids, torch.Tensor): - return f"Decoded: {token_ids.tolist()}" - return f"Decoded: {token_ids}" - - -class TestChatAction(unittest.TestCase): - """Test the ChatAction class.""" - - def test_init(self): - """Test initialization of ChatAction.""" - tokens = torch.tensor([1, 2, 3]) - action = ChatAction(tokens=tokens) - self.assertTrue(torch.equal(action.tokens, tokens)) - - def test_init_empty_tokens(self): - """Test initialization with empty tokens raises ValueError.""" - with self.assertRaises(ValueError): - ChatAction(tokens=torch.tensor([])) - - -class TestChatState(unittest.TestCase): - """Test the ChatState class.""" - - def test_init(self): - """Test initialization of ChatState.""" - state = ChatState() - self.assertEqual(state.history_messages, []) - self.assertEqual(state.history_tokens, []) - - def test_init_with_values(self): - """Test initialization with provided values.""" - messages: list[Message] = [{"role": "user", "content": "Hello"}] - tokens = [torch.tensor([1, 2, 3])] - state = ChatState(history_messages=messages, history_tokens=tokens) - self.assertEqual(state.history_messages, messages) - self.assertEqual(state.history_tokens, tokens) - - -class TestChatObservation(unittest.TestCase): - """Test the ChatObservation class.""" - - def test_init(self): - """Test initialization of ChatObservation.""" - obs = ChatObservation() - self.assertEqual(obs.messages, []) - self.assertEqual(obs.tokens.numel(), 0) - self.assertFalse(obs.done) - self.assertIsNone(obs.reward) - self.assertEqual(obs.metadata, {}) - - def test_init_with_values(self): - """Test initialization with provided values.""" - messages: list[Message] = [{"role": "user", "content": "Hello"}] - tokens = torch.tensor([1, 2, 3]) - obs = ChatObservation( - messages=messages, - tokens=tokens, - done=True, - reward=1.0, - metadata={"test": "value"}, - ) - self.assertEqual(obs.messages, messages) - self.assertTrue(torch.equal(obs.tokens, tokens)) - self.assertTrue(obs.done) - self.assertEqual(obs.reward, 1.0) - self.assertEqual(obs.metadata, {"test": "value"}) - - -class TestChatEnvironment(unittest.TestCase): - """Test the ChatEnvironment class.""" - - def setUp(self): - """Set up test fixtures.""" - self.tokenizer = MockTokenizer() - - def test_init_no_system_prompt(self): - """Test initialization without system prompt.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - self.assertEqual(env._state.history_messages, []) - self.assertEqual(env._state.history_tokens, []) - - def test_init_with_system_prompt(self): - """Test initialization with system prompt.""" - env = ChatEnvironment( - tokenizer=self.tokenizer, - system_prompt="You are a helpful assistant", - system_role="system", - ) - self.assertEqual(len(env._state.history_messages), 1) - self.assertEqual(env._state.history_messages[0]["role"], "system") - self.assertEqual( - env._state.history_messages[0]["content"], "You are a helpful assistant" - ) - self.assertEqual(len(env._state.history_tokens), 1) - - def test_init_invalid_tokenizer(self): - """Test initialization with invalid tokenizer.""" - # Create a mock with no attributes by setting spec=[] - invalid_tokenizer = MagicMock(spec=[]) - with self.assertRaises(ValueError): - ChatEnvironment(tokenizer=invalid_tokenizer) - - def test_reset_no_system_prompt(self): - """Test reset without system prompt.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - # Add some history first - env._state.history_messages = [{"role": "user", "content": "Hello"}] # type: ignore - env._state.history_tokens = [torch.tensor([1, 2, 3])] - - # Reset should clear the history - obs = env.reset() - self.assertEqual(env._state.history_messages, []) - self.assertEqual(env._state.history_tokens, []) - self.assertEqual(obs.messages, []) - self.assertEqual(obs.tokens.numel(), 0) - - def test_reset_with_system_prompt(self): - """Test reset with system prompt.""" - env = ChatEnvironment( - tokenizer=self.tokenizer, - system_prompt="You are a helpful assistant", - system_role="system", - ) - # Add some history first - env._state.history_messages = [ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "Hello"}, - ] # type: ignore - env._state.history_tokens = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])] - - # Reset should clear the history and add the system prompt - obs = env.reset() - self.assertEqual(len(env._state.history_messages), 1) - self.assertEqual(env._state.history_messages[0]["role"], "system") - self.assertEqual( - env._state.history_messages[0]["content"], "You are a helpful assistant" - ) - self.assertEqual(len(env._state.history_tokens), 1) - self.assertEqual(len(obs.messages), 1) - self.assertEqual(obs.messages[0]["role"], "system") - self.assertEqual(obs.messages[0]["content"], "You are a helpful assistant") - - def test_step(self): - """Test step method.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - action = ChatAction(tokens=torch.tensor([1, 2, 3])) - - obs = env.step(action) - - # Check that the tokens were added to history - self.assertEqual(len(env._state.history_tokens), 1) - self.assertTrue( - torch.equal(env._state.history_tokens[0], torch.tensor([1, 2, 3])) - ) - - # Check that the message was added to history with decoded content - self.assertEqual(len(env._state.history_messages), 1) - self.assertEqual(env._state.history_messages[0]["role"], "assistant") - self.assertEqual( - env._state.history_messages[0]["content"], "Decoded: [1, 2, 3]" - ) - - # Check the observation - self.assertEqual(len(obs.messages), 1) - self.assertEqual(obs.messages[0]["role"], "assistant") - self.assertEqual(obs.messages[0]["content"], "Decoded: [1, 2, 3]") - - def test_create_observation(self): - """Test _create_observation method.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - env._state.history_messages = [ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "Hello"}, - ] # type: ignore - env._state.history_tokens = [ - torch.tensor([[1, 2, 3]]), - torch.tensor([[4, 5, 6]]), - ] - - obs = env._create_observation() - - # Check the observation - self.assertEqual(len(obs.messages), 2) - self.assertEqual(obs.messages[0]["role"], "system") - self.assertEqual(obs.messages[0]["content"], "You are a helpful assistant") - self.assertEqual(obs.messages[1]["role"], "user") - self.assertEqual(obs.messages[1]["content"], "Hello") - - # Check that the tokens were concatenated - self.assertEqual(obs.tokens.numel(), 6) # 2 tensors of size 3 - - def test_create_observation_empty_history(self): - """Test _create_observation method with empty history.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - - obs = env._create_observation() - - # Check the observation - self.assertEqual(obs.messages, []) - self.assertEqual(obs.tokens.numel(), 0) - - def test_state_property(self): - """Test state property.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - state = env.state - self.assertIsInstance(state, ChatState) - self.assertEqual(state.history_messages, []) - self.assertEqual(state.history_tokens, []) - - def test_message_to_action(self): - """Test message_to_action method.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - message: Message = {"role": "user", "content": "Hello"} - - action = env.message_to_action(message) - - self.assertIsInstance(action, ChatAction) - self.assertEqual( - action.tokens.numel(), 10 - ) # Mock tokenizer returns 10 tokens per message - - def test_message_to_action_missing_role(self): - """Test message_to_action method with missing role.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - # We're intentionally creating an invalid message to test error handling - message = {"content": "Hello"} # type: ignore - - with self.assertRaises(ValueError): - # Using type: ignore because we're intentionally passing an invalid message - env.message_to_action(message) # type: ignore - - def test_message_to_action_missing_content(self): - """Test message_to_action method with missing content.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - # We're intentionally creating an invalid message to test error handling - message = {"role": "user"} # type: ignore - - with self.assertRaises(ValueError): - # Using type: ignore because we're intentionally passing an invalid message - env.message_to_action(message) # type: ignore - - def test_message_to_action_none_content(self): - """Test message_to_action method with None content.""" - env = ChatEnvironment(tokenizer=self.tokenizer) - # We're intentionally creating an invalid message to test error handling - message = {"role": "user", "content": None} # type: ignore - - with self.assertRaises(ValueError): - # Using type: ignore because we're intentionally passing an invalid message - env.message_to_action(message) # type: ignore - - def test_with_transform(self): - """Test environment with a transform.""" - - def transform(obs): - obs.metadata["transformed"] = True - obs.reward = 1.0 - return obs - - env = ChatEnvironment(tokenizer=self.tokenizer, transform=transform) - action = ChatAction(tokens=torch.tensor([1, 2, 3])) - - obs = env.step(action) - - self.assertTrue(obs.metadata.get("transformed")) - self.assertEqual(obs.reward, 1.0) - - -if __name__ == "__main__": - unittest.main()