diff --git a/README.md b/README.md
index b1964314..bd04999b 100644
--- a/README.md
+++ b/README.md
@@ -14,11 +14,18 @@ The InstructLab Training library is an optimized model instruction-tuning librar
 To simplify the process of fine-tuning models with the [LAB method](https://arxiv.org/abs/2403.01081), or for general use, this library provides a simple pythonic training interface.
 
+### Reasoning Content Support
+
+The library now supports reasoning traces through the `reasoning_content` field in message samples. This enables training on both regular content and structured reasoning traces, making it possible to train reasoning-capable models that separate their thinking process from their final output.
+
 ## Usage and Guidance Sections
 
 - [Installing](#installing-the-library)
   - [Additional Nvidia packages](#additional-nvidia-packages)
 - [Using the library](#using-the-library)
+- [Data format](#data-format)
+  - [Reasoning content support](#reasoning-content-support-1)
+- [Documentation](#documentation)
 - [Learning about the training arguments](#learning-about-training-arguments)
   - [`TrainingArgs`](#trainingargs)
   - [`DeepSpeedOptions`](#deepspeedoptions)
@@ -80,6 +87,72 @@ You can then define various training arguments. They will serve as the parameter
 - [Learning about the training argument](#learning-about-training-arguments)
 - [Example training run with arguments](#example-training-run-with-arguments)
 
+## Data format
+
+The library expects training data in the messages format, where each sample contains a list of messages with different roles (user, assistant, system, etc.). Each message should have at minimum:
+
+- `role`: The role of the message sender (e.g., "user", "assistant", "system")
+- `content`: The main content of the message
+
+### Reasoning content support
+
+The library now supports an optional `reasoning_content` field in addition to the standard `content` field. This enables training models with structured reasoning traces. The `reasoning_content` field is particularly useful for:
+
+- Training reasoning-capable models that can separate their thinking process from their output
+- Supporting models that need to generate internal reasoning traces
+- Enabling step-by-step reasoning in model responses
+
+> **Note**: This is only supported for models whose chat templates use the DeepSeek R1-style parser. Models without a custom thought processor, such as Phi-4, must still provide their reasoning traces in the `content` field.
+
+**Example message structure with reasoning content:**
+
+```json
+{
+  "messages": [
+    {
+      "role": "user",
+      "content": "What is 15 * 23?"
+    },
+    {
+      "role": "assistant",
+      "reasoning_content": "I need to multiply 15 by 23. Let me break this down: 15 * 23 = 15 * (20 + 3) = 15 * 20 + 15 * 3 = 300 + 45 = 345",
+      "content": "15 * 23 = 345"
+    }
+  ]
+}
+```
+
+**Standard message structure:**
+
+```json
+{
+  "messages": [
+    {
+      "role": "user",
+      "content": "Hello! How are you?"
+    },
+    {
+      "role": "assistant",
+      "content": "Hello! I'm doing well, thank you for asking. How can I help you today?"
+    }
+  ]
+}
+```
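+
+Before launching a run, it can be worth sanity-checking that your dataset follows this format. The snippet below is a minimal illustrative sketch; the `train.jsonl` path and the `validate_sample` helper are hypothetical and not part of the library API:
+
+```python
+# Illustrative sketch: check that each sample follows the messages format
+# described above. `validate_sample` is a hypothetical helper, not library API.
+import json
+
+
+def validate_sample(sample: dict) -> None:
+    for msg in sample["messages"]:
+        # every message needs a string role
+        assert isinstance(msg["role"], str)
+        # content is the usual payload and must be a string when present
+        if "content" in msg:
+            assert isinstance(msg["content"], str)
+        # reasoning_content is optional, but must also be a string when present
+        if "reasoning_content" in msg:
+            assert isinstance(msg["reasoning_content"], str)
+
+
+with open("train.jsonl", encoding="utf-8") as f:  # hypothetical path
+    for line in f:
+        validate_sample(json.loads(line))
+```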
+
+#### Important Notes
+
+1. **Automatic reasoning content processing**: If `reasoning_content` exists in a message, it will always be processed and unmasked as long as the message role is targeted for unmasking. This ensures that reasoning traces are properly included in the training data.
+
+2. **DeepSeek R1 Thinking Compatibility**: Models using the DeepSeek R1 thought processor (such as Qwen3) must supply their thinking traces in the `reasoning_content` field to be processed correctly. Failure to do so may result in improper handling of reasoning tokens and suboptimal training performance.
+
+## Documentation
+
+For detailed information about specific features:
+
+- **[Reasoning Content Support](docs/reasoning_content.md)**: Comprehensive guide to using the `reasoning_content` field for training reasoning-capable models
+- **[CI Documentation](docs/ci.md)**: Information about continuous integration processes
+- **[Logging Documentation](docs/logging.md)**: Guide to logging configuration and usage
+
 ## Learning about training arguments
 
 The `TrainingArgs` class provides most of the customization options
@@ -378,4 +451,4 @@ Below is a list of custom environment variables users can set in the training li
 
 ## Developer Certificate of Origin
 
-When you make a contribution to InstructLab training, you implicitly agree to the Developer Certificate of Origin terms as set in `DCO.txt` at the root of this repository.
\ No newline at end of file
+When you make a contribution to InstructLab training, you implicitly agree to the Developer Certificate of Origin terms as set in `DCO.txt` at the root of this repository.
diff --git a/docs/reasoning_content.md b/docs/reasoning_content.md
new file mode 100644
index 00000000..766b39c6
--- /dev/null
+++ b/docs/reasoning_content.md
@@ -0,0 +1,181 @@
+# Reasoning Content Support
+
+The InstructLab Training library supports structured reasoning traces through the `reasoning_content` field in message samples. This feature enables training models that can separate their thinking process from their final output.
+
+## Overview
+
+The `reasoning_content` field is an optional addition to the standard message format that allows you to include the model's internal reasoning process alongside the final response. This is particularly useful for:
+
+- Training reasoning-capable models that show their work
+- Supporting models that need to generate step-by-step reasoning
+- Enabling chain-of-thought style training data
+- Separating internal thinking from user-facing responses
+
+## Message Format
+
+### Standard Message Format
+
+```json
+{
+  "role": "assistant",
+  "content": "The answer is 42."
+}
+```
+
+### Extended Message Format with Reasoning Content
+
+```json
+{
+  "role": "assistant",
+  "content": "The answer is 42.",
+  "reasoning_content": "Let me think about this step by step. The question asks for the meaning of life, and according to The Hitchhiker's Guide to the Galaxy, the answer is 42."
+}
+```
+
+## Data Processing Behavior
+
+When processing messages during training:
+
+1. **Unmasking Rules**: Both `content` and `reasoning_content` fields follow the same unmasking rules based on the message role
+2. **Template Integration**: Both fields are processed by the chat template and included in the tokenized output
+3. **Token Wrapping**: If a role is configured to be unmasked, both fields (when present) are wrapped with unmask tokens, as shown in the sketch below
+4. **Independent Fields**: Either field can exist independently: messages can have only `content`, only `reasoning_content`, or both
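+
+The sketch below illustrates the wrapping step (rules 1 and 3 above) using the library's `wrap_masked_messages` helper from `src/instructlab/training/data_process.py`. The exact rendered prompt depends on your tokenizer's chat template, so treat the values shown in the comments as the intermediate representation rather than final token IDs:
+
+```python
+# Minimal sketch of the token-wrapping step applied to an assistant message.
+from instructlab.training.data_process import wrap_masked_messages
+
+msgs = [
+    {"role": "user", "content": "What is 2+2?"},
+    {
+        "role": "assistant",
+        "content": "4",
+        "reasoning_content": "2 plus 2 equals 4.",
+    },
+]
+
+wrapped = wrap_masked_messages(msgs, ["assistant"], enable_reasoning_content=True)
+# The user message passes through untouched; the assistant message becomes:
+#   content:           "<|UNMASK_BEGIN|>4<|UNMASK_END|>"
+#   reasoning_content: "<|UNMASK_REASONING_BEGIN|>2 plus 2 equals 4.<|UNMASK_REASONING_END|>"
+```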
+
+## Usage Examples
+
+### Training Data with Reasoning Traces
+
+```json
+{
+  "messages": [
+    {
+      "role": "user",
+      "content": "What is 15 * 23?"
+    },
+    {
+      "role": "assistant",
+      "reasoning_content": "I need to multiply 15 by 23. Let me break this down: 15 * 23 = 15 * (20 + 3) = 15 * 20 + 15 * 3 = 300 + 45 = 345",
+      "content": "15 * 23 = 345"
+    }
+  ]
+}
+```
+
+### Mixed Content Types
+
+```json
+{
+  "messages": [
+    {
+      "role": "user",
+      "content": "Solve this math problem step by step: 2x + 5 = 13"
+    },
+    {
+      "role": "assistant",
+      "reasoning_content": "I need to solve for x. First, I'll subtract 5 from both sides: 2x = 8. Then divide by 2: x = 4.",
+      "content": "To solve 2x + 5 = 13:\n1. Subtract 5 from both sides: 2x = 8\n2. Divide by 2: x = 4\n\nTherefore, x = 4."
+    }
+  ]
+}
+```
+
+### Reasoning-Only Responses
+
+```json
+{
+  "messages": [
+    {
+      "role": "user",
+      "content": "Think about the implications of AI safety."
+    },
+    {
+      "role": "assistant",
+      "reasoning_content": "This is a complex topic that requires careful consideration of multiple factors including alignment, capability control, and social implications..."
+    }
+  ]
+}
+```
+
+## Implementation Details
+
+### Token Processing
+
+During data processing, the library:
+
+1. Wraps `content` with the special `<|UNMASK_BEGIN|>`/`<|UNMASK_END|>` tokens and `reasoning_content` with the special `<|UNMASK_REASONING_BEGIN|>`/`<|UNMASK_REASONING_END|>` tokens
+2. Applies the chat template to the wrapped messages
+3. Processes the tokenized sequence to create appropriate labels for training
+4. Removes the special unmask tokens from the final training data
+
+### Validation
+
+The library validates that:
+
+- Both `content` and `reasoning_content` are strings, when present
+- Special unmask tokens are properly processed and removed
+- The final training data contains no residual unmask tokens
+
+### Error Handling
+
+Common errors and their meanings:
+
+- `"unmasking non-string data types is currently unsupported"`: the `content` field contains non-string data
+- `"received an entry for reasoning_content which was not a string"`: the `reasoning_content` field contains non-string data
+
+## Integration with Existing Features
+
+### Unmasking Policies
+
+The `reasoning_content` field respects all existing unmasking policies:
+
+- When `unmask=true` is set on a sample, both fields are unmasked for non-system roles (see the example below)
+- When `unmask=false` (the default), only assistant role messages are unmasked
+- Custom unmask role configurations work with both fields
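+
+For example, a sample that opts into unmasking every non-system role might look like this (illustrative):
+
+```json
+{
+  "messages": [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "What is 2+2?"},
+    {
+      "role": "assistant",
+      "content": "4",
+      "reasoning_content": "2 plus 2 equals 4."
+    }
+  ],
+  "unmask": true
+}
+```
+
+Here both the user and assistant messages (including the assistant's `reasoning_content`) contribute to the training labels, while the system message remains masked.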
+
+### Chat Templates
+
+The `reasoning_content` field is not supported by the legacy chat templates and will not be rendered by them.
+
+### Backward Compatibility
+
+The feature is fully backward compatible:
+
+- Existing datasets without `reasoning_content` continue to work unchanged
+- All existing training configurations and arguments remain valid
+
+## Testing
+
+The library includes comprehensive tests for reasoning content functionality:
+
+- Unit tests for message wrapping and processing
+- Integration tests with real tokenizers
+- Validation tests for error conditions
+- Backward compatibility tests
+
+## Important Notes
+
+### Automatic Processing Behavior
+
+1. **Always processed when present**: If `reasoning_content` exists in a message, it will always be processed and unmasked as long as the message role is targeted for unmasking. This ensures that reasoning traces are properly included in the training data without requiring additional configuration.
+
+2. **DeepSeek R1 and Qwen3 compatibility**: Models using the DeepSeek R1 thought processor (such as Qwen3) **must** supply their thinking traces in the `reasoning_content` field to be processed correctly. Failure to do so may result in improper handling of reasoning tokens and suboptimal training performance.
+
+3. **Separate token handling**: The library uses distinct unmask tokens for reasoning content (`<|UNMASK_REASONING_BEGIN|>` and `<|UNMASK_REASONING_END|>`) versus regular content (`<|UNMASK_BEGIN|>` and `<|UNMASK_END|>`), allowing for proper differentiation during training.
+
+## Best Practices
+
+1. **Consistent Usage**: Where applicable, use `reasoning_content` consistently within a dataset for best results
+2. **Clear Separation**: Keep reasoning traces separate from final outputs for clarity
+3. **Template Compatibility**: Ensure your chat template properly handles both fields
+4. **Validation**: Test your data processing pipeline with small samples before full training
+
+## Migration Guide
+
+To add reasoning content support to existing datasets:
+
+1. Add `reasoning_content` fields to the relevant messages
+2. Ensure the field values are strings
+3. Test with a small sample using the data processing pipeline
+4. Verify that unmask tokens are properly processed
+
+No changes to training arguments or configuration are required.
diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py
index 0f283ded..3cee19f2 100644
--- a/src/instructlab/training/data_process.py
+++ b/src/instructlab/training/data_process.py
@@ -26,6 +26,8 @@
 MASK_TOKEN = "<|MASK|>"
 UNMASK_BEGIN_TOKEN = "<|UNMASK_BEGIN|>"
 UNMASK_END_TOKEN = "<|UNMASK_END|>"
+UNMASK_REASONING_BEGIN_TOKEN = "<|UNMASK_REASONING_BEGIN|>"
+UNMASK_REASONING_END_TOKEN = "<|UNMASK_REASONING_END|>"
 
 logger = logging.getLogger(__name__)
 
@@ -439,26 +441,74 @@ def process_messages_into_input_ids_with_chat_template(args: DataProcessArgs):
 def wrap_masked_messages(
-    msgs: t.List[Message], unmask_roles: t.List[str]
+    msgs: t.List[Message],
+    unmask_roles: t.List[str],
+    enable_reasoning_content: bool = False,
 ) -> t.List[Message]:
     """
     Given a list of messages and a set of roles we want to unmask, return
-    a list with the matching messages wrapped with `<|UNMASK_BEGIN|>` and `<|UNMASK_END|>` tokens
-    wrapped around the `message.content` field.
+    a list with the matching messages wrapped with unmask tokens.
 
     Args:
         msgs (List[Message]): List of messages we want to wrap with unmask tokens.
         unmask_roles (List[str]): The roles whose messages we should wrap.
+        enable_reasoning_content (bool): Whether to wrap reasoning_content fields.
+            When True, reasoning_content is wrapped with UNMASK_REASONING_BEGIN/END tokens.
+            When False, reasoning_content is left unchanged.
 
     Returns:
         List[Message]: The resultant list with all appropriate messages wrapped.
+
+    Note:
+        The `content` field is wrapped with UNMASK_BEGIN/END tokens.
+        The `reasoning_content` field (if present and enable_reasoning_content=True)
+        is wrapped with UNMASK_REASONING_BEGIN/END tokens.
     """
     new_msgs: t.List[Message] = []
     for msg in msgs:
-        content = msg["content"]
-        if msg["role"] in unmask_roles:
-            content = UNMASK_BEGIN_TOKEN + content + UNMASK_END_TOKEN
-        new_msgs.append({"role": msg["role"], "content": content})
+        if msg["role"] not in unmask_roles:
+            # do nothing
+            new_msgs += [msg]
+            continue
+
+        # here, we need to be on the lookout for both string and non-string
+        # entries (e.g. 
other content types, or pure reasoning traces) + interesting_fields = ["content", "reasoning_content"] + new_msg = {k: v for k, v in msg.items() if k not in interesting_fields} + + # what's left to add then is content or reasoning_content + content = msg.get("content", None) + reasoning_content = msg.get("reasoning_content", None) + + # we handle these conditionally since these may become optional fields in the future. + if content is not None: + if not isinstance(content, str): + raise ValueError( + "Error: unmasking non-string data types is currently unsupported. " + ) + new_msg["content"] = UNMASK_BEGIN_TOKEN + content + UNMASK_END_TOKEN + + if reasoning_content is not None: + if enable_reasoning_content: + if not isinstance(reasoning_content, str): + raise ValueError( + "Error: received an entry for `reasoning_content` which was not a string. " + "Non-string datatypes for this field are currently unsupported, if this is intentional please raise an issue." + ) + + new_msg["reasoning_content"] = ( + UNMASK_REASONING_BEGIN_TOKEN + + reasoning_content + + UNMASK_REASONING_END_TOKEN + ) + else: + # When not enabled, pass through unchanged + new_msg["reasoning_content"] = reasoning_content + + # MyPy wants to be very specific about types, but new_msg may contain + # valid fields in each message which are hard to account for ahead of time. + new_msgs += [new_msg] # type: ignore + return new_msgs @@ -468,23 +518,22 @@ def unmask_messages( unmask_roles: t.List[str], ) -> ProcessedMessagesData: """ - Algorithm to unmask messages with any arbitrary Tokenizer, provided the following - conditions are satisfied: - - 1. A chat template has been set on the tokenizer - 2. The tokenizer has either an end-of-sequence, or a padding token that can be used as a fallback. - + Algorithm to unmask messages with any arbitrary Tokenizer, with support for + reasoning content. The algorithm handles both regular content and reasoning + content fields, merging adjacent unmask regions as needed. The algorithm works like this: - 1. Wrap all messages with `<|UNMASK_BEGIN|>` and `<|UNMASK_END|>` special tokens for all roles in `unmask_roles` - 2. Apply the chat template on the resultant messages - 3. Add all IDs seen into input_ids, and only those ids found within the special unmask boundary tokens into labels + 1. Wrap messages with unmask tokens: + - content: wrapped with UNMASK_BEGIN/END tokens + - reasoning_content: wrapped with UNMASK_REASONING_BEGIN/END tokens + 2. Apply the chat template on the wrapped messages + 3. Process the token sequence to identify and merge unmask regions + 4. Generate labels based on the unmask regions **Note**: If a tokenizer has an end-of-sequence token, it is only ever unmasked for the `assistant` role. This helps prevent confusion for the model when learning to predict the next token. - Please reach out to the instructlab/training maintainers if you need this behavior changed. Args: msgs (List[Message]): A list of messages. 
@@ -495,97 +544,213 @@ def unmask_messages( Returns: Result (ProcessedMessagesData): Dict with the resulting `input_ids`, `labels`, and `len` """ - msgs_with_unmasking = wrap_masked_messages(msgs, unmask_roles) - input_ids = tokenizer.apply_chat_template(msgs_with_unmasking) + # Check if any messages have reasoning_content that we need to handle + has_reasoning = any( + msg.get("reasoning_content") is not None + for msg in msgs + if msg["role"] in unmask_roles + ) - # get the order of unmasked roles - unmask_roles_order = iter( - [msg["role"] for msg in msgs_with_unmasking if msg["role"] in unmask_roles] + # TODO(osilkin): Here we assume that we will always unmask reasoning content, + # in the future we can make this configurable. + msgs_with_unmasking = wrap_masked_messages( + msgs, unmask_roles, enable_reasoning_content=has_reasoning ) - # get token ids + # Create a mapping of message index to expected regions + message_regions_map = {} + for idx, msg in enumerate(msgs_with_unmasking): + if msg["role"] in unmask_roles: + regions = [] + if has_reasoning and msg.get("reasoning_content") is not None: + regions.append("reasoning") + if msg.get("content") is not None: + regions.append("content") + if regions: + message_regions_map[idx] = regions + + input_ids = tokenizer.apply_chat_template(msgs_with_unmasking) + + # Get token IDs for all unmask tokens unmask_begin_token_id = tokenizer.encode( UNMASK_BEGIN_TOKEN, add_special_tokens=False )[0] unmask_end_token_id = tokenizer.encode(UNMASK_END_TOKEN, add_special_tokens=False)[ 0 ] + unmask_reasoning_begin_token_id = tokenizer.encode( + UNMASK_REASONING_BEGIN_TOKEN, add_special_tokens=False + )[0] + unmask_reasoning_end_token_id = tokenizer.encode( + UNMASK_REASONING_END_TOKEN, add_special_tokens=False + )[0] + eos_token_id = None if tokenizer.eos_token is not None: eos_token_id = tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)[ 0 ] - final_input_ids = [] - final_labels = [] + # First pass: identify unmask regions and their types + unmask_regions = [] i = 0 - unmasking = False - role_being_actively_unmasked = None - - # pylint: disable=too-many-nested-blocks while i < len(input_ids): tok = input_ids[i] - if unmasking: - if tok == unmask_begin_token_id: - raise ValueError( - f'encountered a "{UNMASK_BEGIN_TOKEN}" token while already unmasking. This should never happen, pleas contact the training maintainers.' - ) - if tok == unmask_end_token_id: - # we need to just make sure that we unmask the EOS token for the assistant role - if ( - eos_token_id is not None - and role_being_actively_unmasked == "assistant" - ): - # TODO(osilkin): clean up this portion so that we don't run into race conditions or other bugs - i += 1 - while i < len(input_ids): - final_input_ids.append(input_ids[i]) - final_labels.append(input_ids[i]) - if input_ids[i] == eos_token_id: - break - i += 1 - unmasking = False - role_being_actively_unmasked = None + # Check for orphaned end tokens + if tok == unmask_end_token_id: + raise ValueError( + f'encountered an "{UNMASK_END_TOKEN}" token while not unmasking. This should never happen, please contact the training maintainers.' + ) + + if tok == unmask_reasoning_end_token_id: + raise ValueError( + f'encountered an "{UNMASK_REASONING_END_TOKEN}" token while not unmasking. This should never happen, please contact the training maintainers.' 
+ ) + + # Check for unmask begin tokens + if tok == unmask_begin_token_id: + # Find the matching end token + j = i + 1 + while j < len(input_ids) and input_ids[j] != unmask_end_token_id: + # Check for nested begin tokens + if input_ids[j] == unmask_begin_token_id: + raise ValueError( + f'encountered a "{UNMASK_BEGIN_TOKEN}" token while already unmasking. This should never happen, please contact the training maintainers.' + ) + j += 1 + if j < len(input_ids): + unmask_regions.append((i, j, "content")) + i = j else: - final_input_ids.append(tok) - final_labels.append(tok) - else: - if tok == unmask_end_token_id: - raise ValueError( - f'encountered an "{UNMASK_END_TOKEN}" token while not unmasking. This should never happen, please contact the training maintainers.' + raise RuntimeError( + "suffered a critical failure: unmasking finished but not all messages were processed. Please report this!" ) - - if tok == unmask_begin_token_id: - unmasking = True - role_being_actively_unmasked = next(unmask_roles_order, None) + elif tok == unmask_reasoning_begin_token_id: + # Find the matching end token + j = i + 1 + while j < len(input_ids) and input_ids[j] != unmask_reasoning_end_token_id: + # Check for nested begin tokens + if input_ids[j] == unmask_reasoning_begin_token_id: + raise ValueError( + f'encountered a "{UNMASK_REASONING_BEGIN_TOKEN}" token while already unmasking. This should never happen, please contact the training maintainers.' + ) + j += 1 + if j < len(input_ids): + unmask_regions.append((i, j, "reasoning")) + i = j else: - final_input_ids.append(tok) - final_labels.append(-100) + raise RuntimeError( + "suffered a critical failure: unmasking finished but not all messages were processed. Please report this!" + ) + i += 1 - # validation logic - if unmask_begin_token_id in final_input_ids: - raise ValueError( - f"{UNMASK_BEGIN_TOKEN} token found in final_input_ids. This should never happen, please contact the training maintainers." - ) - if unmask_begin_token_id in final_labels: - raise ValueError( - f"{UNMASK_BEGIN_TOKEN} token found in final_labels. This should never happen, please contact the training maintainers." - ) - if unmask_end_token_id in final_input_ids: - raise ValueError( - f"{UNMASK_END_TOKEN} token found in final_input_ids. This should never happen, please contact the training maintainers." - ) - if unmask_end_token_id in final_labels: - raise ValueError( - f"{UNMASK_END_TOKEN} token found in final_labels. This should never happen, please contact the training maintainers." - ) + # Group regions by message and merge if they belong to the same message + # First, we need to map regions back to their source messages + region_to_message_map = {} + region_idx = 0 + for msg_idx, expected_regions in message_regions_map.items(): + for expected_type in expected_regions: + # Find the next region of the expected type + while region_idx < len(unmask_regions): + if unmask_regions[region_idx][2] == expected_type: + region_to_message_map[region_idx] = ( + msg_idx, + msgs_with_unmasking[msg_idx]["role"], + ) + region_idx += 1 + break + region_idx += 1 + + # Now merge regions that belong to the same message + merged_regions: list[tuple[int, int, str, str | None]] = [] + i = 0 + while i < len(unmask_regions): + start, end, region_type = unmask_regions[i] + msg_info = region_to_message_map.get(i) - if role_being_actively_unmasked is not None: - raise RuntimeError( - "suffered a critical failure: unmasking finished but not all messages were processed. Please report this!" 
- ) + if msg_info is None: + # This shouldn't happen, but if it does, keep the region as-is + merged_regions.append((start, end, region_type, None)) + i += 1 + continue + + msg_idx, role = msg_info + + # Check if the next region belongs to the same message + if i + 1 < len(unmask_regions) and (i + 1) in region_to_message_map: + next_msg_idx, _ = region_to_message_map[i + 1] + if msg_idx == next_msg_idx: + # Same message - merge the regions + _, next_end, _ = unmask_regions[i + 1] + merged_regions.append((start, next_end, "merged", role)) + i += 2 + continue + + # Not merged - keep as is + merged_regions.append((start, end, region_type, role)) + i += 1 + + # Build the final token sequences + final_input_ids = [] + final_labels = [] + unmask_tokens = { + unmask_begin_token_id, + unmask_end_token_id, + unmask_reasoning_begin_token_id, + unmask_reasoning_end_token_id, + } + + # Track which tokens to unmask based on regions + tokens_to_unmask = set() + for start, end, _, region_role in merged_regions: + for idx in range(start + 1, end): + if input_ids[idx] not in unmask_tokens: + tokens_to_unmask.add(idx) + + # For assistant messages, also unmask tokens after the region until EOS + if eos_token_id is not None and region_role == "assistant": + # Look for EOS token after the region + j = end + 1 + while j < len(input_ids): + if input_ids[j] == eos_token_id: + # Unmask everything from end of region to EOS (inclusive) + for k in range(end + 1, j + 1): + tokens_to_unmask.add(k) + break + # Stop if we encounter another unmask region start + if input_ids[j] in { + unmask_begin_token_id, + unmask_reasoning_begin_token_id, + }: + break + j += 1 + + # Generate final sequences + for i, tok in enumerate(input_ids): + if tok not in unmask_tokens: + final_input_ids.append(tok) + if i in tokens_to_unmask: + final_labels.append(tok) + else: + final_labels.append(-100) + + # Validation logic + for tok_id, tok_name in [ + (unmask_begin_token_id, UNMASK_BEGIN_TOKEN), + (unmask_end_token_id, UNMASK_END_TOKEN), + (unmask_reasoning_begin_token_id, UNMASK_REASONING_BEGIN_TOKEN), + (unmask_reasoning_end_token_id, UNMASK_REASONING_END_TOKEN), + ]: + if tok_id in final_input_ids: + raise ValueError( + f"{tok_name} token found in final_input_ids. This should never happen, please contact the training maintainers." + ) + if tok_id in final_labels: + raise ValueError( + f"{tok_name} token found in final_labels. This should never happen, please contact the training maintainers." + ) if len(final_input_ids) != len(final_labels): raise RuntimeError( @@ -906,6 +1071,8 @@ def configure_tokenizer(model_path: str) -> PreTrainedTokenizer: "additional_special_tokens": [ UNMASK_BEGIN_TOKEN, UNMASK_END_TOKEN, + UNMASK_REASONING_BEGIN_TOKEN, + UNMASK_REASONING_END_TOKEN, MASK_TOKEN, ] } diff --git a/src/instructlab/training/type_definitions.py b/src/instructlab/training/type_definitions.py index 1240c76e..0b8866a3 100644 --- a/src/instructlab/training/type_definitions.py +++ b/src/instructlab/training/type_definitions.py @@ -10,14 +10,35 @@ # Standard import typing as t +# For Python 3.8+ compatibility +try: + # Standard + from typing import NotRequired, Required +except ImportError: + try: + # Third Party + from typing_extensions import NotRequired, Required + except ImportError: + # Fallback for older Python versions + Required = t.Annotated + NotRequired = t.Annotated + class Message(t.TypedDict): """ Format of a single message sample. + + Fields: + content: The main content of the message. 
+ role: The role of the message sender (e.g., "user", "assistant", "system"). + reasoning_content: Optional reasoning trace or thinking process associated with the message. + This field is particularly useful for training reasoning-capable models + that can separate their thinking process from their final output. """ - content: str - role: str + content: Required[str] + role: Required[str] + reasoning_content: NotRequired[str] class ProcessedMessagesData(t.TypedDict): diff --git a/tests/unit/test_data_process.py b/tests/unit/test_data_process.py new file mode 100644 index 00000000..ba0bc4cc --- /dev/null +++ b/tests/unit/test_data_process.py @@ -0,0 +1,994 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +from unittest.mock import Mock, patch +import tempfile +import typing as t +import unittest + +try: + # Third Party + import pytest + + PYTEST_AVAILABLE = True +except ImportError: + PYTEST_AVAILABLE = False + +try: + # Third Party + from transformers import AutoTokenizer, PreTrainedTokenizer + + TRANSFORMERS_AVAILABLE = True +except ImportError: + TRANSFORMERS_AVAILABLE = False + PreTrainedTokenizer = None + +# First Party +from instructlab.training.data_process import ( + MASK_TOKEN, + UNMASK_BEGIN_TOKEN, + UNMASK_END_TOKEN, + UNMASK_REASONING_BEGIN_TOKEN, + UNMASK_REASONING_END_TOKEN, + unmask_messages, + unmask_sample, + wrap_masked_messages, +) +from instructlab.training.type_definitions import Message, ProcessedMessagesData + + +class TestComprehensiveUnmasking(unittest.TestCase): + """Comprehensive test suite for unmasking behavior across various scenarios.""" + + def setUp(self): + """Set up test fixtures.""" + # Mock tokenizer for basic tests + if TRANSFORMERS_AVAILABLE: + self.mock_tokenizer = Mock(spec=PreTrainedTokenizer) + else: + self.mock_tokenizer = Mock() + + # Set up token IDs for unmask tokens + self.unmask_begin_id = 1001 + self.unmask_end_id = 1002 + self.eos_id = 1003 + self.think_id = 1004 + self.end_think_id = 1005 + + def mock_encode_special(text, add_special_tokens=False): + if text == UNMASK_BEGIN_TOKEN: + return [self.unmask_begin_id] + elif text == UNMASK_END_TOKEN: + return [self.unmask_end_id] + elif text == "": + return [self.eos_id] + elif text == "": + return [self.think_id] + elif text == "": + return [self.end_think_id] + else: + # Simple hash-based encoding for text + return [hash(text) % 1000 + 100 for _ in text.split()] + + self.mock_tokenizer.encode.side_effect = mock_encode_special + self.mock_tokenizer.decode.side_effect = lambda tokens: " ".join( + [f"token_{t}" for t in tokens] + ) + self.mock_tokenizer.apply_chat_template.side_effect = ( + self._mock_apply_chat_template + ) + self.mock_tokenizer.eos_token = "" + + def _mock_apply_chat_template( + self, + messages: t.List[Message], + tokenize: bool = True, + add_special_tokens: bool = True, + ) -> t.Union[str, t.List[int]]: + """Mock implementation of apply_chat_template.""" + template_tokens = [] + + for msg in messages: + # Add role tokens + role_tokens = [hash(f"<|{msg['role']}|>") % 1000 + 2000] + template_tokens.extend(role_tokens) + + # Add content tokens + if "content" in msg and msg["content"]: + content_tokens = [ + hash(msg["content"]) % 1000 + 3000 for _ in msg["content"].split() + ] + template_tokens.extend(content_tokens) + + # Add reasoning content tokens + if "reasoning_content" in msg and msg["reasoning_content"]: + reasoning_tokens = [ + hash(msg["reasoning_content"]) % 1000 + 4000 + for _ in msg["reasoning_content"].split() + ] + template_tokens.extend(reasoning_tokens) 
+ + if tokenize: + return template_tokens + else: + return " ".join([f"token_{t}" for t in template_tokens]) + + def test_single_turn_assistant_only_content(self): + """Test basic single-turn conversation with assistant content only.""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + result = unmask_messages(messages, self.mock_tokenizer, ["assistant"]) + + self.assertIsInstance(result, dict) + self.assertIn("input_ids", result) + self.assertIn("labels", result) + self.assertIn("len", result) + self.assertEqual(len(result["input_ids"]), len(result["labels"])) + + # Verify unmask tokens are not in final output + self.assertNotIn(self.unmask_begin_id, result["input_ids"]) + self.assertNotIn(self.unmask_end_id, result["input_ids"]) + self.assertNotIn(self.unmask_begin_id, result["labels"]) + self.assertNotIn(self.unmask_end_id, result["labels"]) + + def test_single_turn_assistant_only_reasoning(self): + """Test single-turn with assistant reasoning_content only.""" + messages = [ + {"role": "user", "content": "What is 2+2?"}, + { + "role": "assistant", + "reasoning_content": "I need to add 2 and 2 together.", + }, + ] + + result = unmask_messages(messages, self.mock_tokenizer, ["assistant"]) + + self.assertIsInstance(result, dict) + self.assertIn("input_ids", result) + self.assertIn("labels", result) + self.assertIn("len", result) + self.assertEqual(len(result["input_ids"]), len(result["labels"])) + + def test_single_turn_assistant_both_content_and_reasoning(self): + """Test single-turn with both content and reasoning_content.""" + messages = [ + {"role": "user", "content": "What is 2+2?"}, + { + "role": "assistant", + "content": "The answer is 4.", + "reasoning_content": "I need to add 2 and 2 together.", + }, + ] + + result = unmask_messages(messages, self.mock_tokenizer, ["assistant"]) + + self.assertIsInstance(result, dict) + self.assertIn("input_ids", result) + self.assertIn("labels", result) + self.assertIn("len", result) + self.assertEqual(len(result["input_ids"]), len(result["labels"])) + + def test_multi_turn_conversation_basic(self): + """Test basic multi-turn conversation.""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi! 
How can I help?"}, + {"role": "user", "content": "What's the weather like?"}, + { + "role": "assistant", + "content": "I don't have access to current weather data.", + }, + ] + + result = unmask_messages(messages, self.mock_tokenizer, ["assistant"]) + + self.assertIsInstance(result, dict) + self.assertIn("input_ids", result) + self.assertIn("labels", result) + self.assertIn("len", result) + self.assertEqual(len(result["input_ids"]), len(result["labels"])) + + def test_multi_turn_with_reasoning_content(self): + """Test multi-turn conversation with reasoning content in multiple turns.""" + messages = [ + {"role": "user", "content": "What is 5*7?"}, + { + "role": "assistant", + "content": "35", + "reasoning_content": "5 times 7 equals 35", + }, + {"role": "user", "content": "What about 6*8?"}, + { + "role": "assistant", + "content": "48", + "reasoning_content": "6 times 8 equals 48", + }, + ] + + result = unmask_messages(messages, self.mock_tokenizer, ["assistant"]) + + self.assertIsInstance(result, dict) + self.assertIn("input_ids", result) + self.assertIn("labels", result) + self.assertIn("len", result) + self.assertEqual(len(result["input_ids"]), len(result["labels"])) + + def test_multi_turn_mixed_content_types(self): + """Test multi-turn with mixed content types (some with reasoning, some without).""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, # No reasoning + {"role": "user", "content": "What is 2+2?"}, + { + "role": "assistant", + "content": "The answer is 4.", + "reasoning_content": "I need to add 2 and 2.", + }, # Both content and reasoning + {"role": "user", "content": "Think about the meaning of life."}, + { + "role": "assistant", + "reasoning_content": "This is a deep philosophical question.", + }, # Reasoning only + ] + + result = unmask_messages(messages, self.mock_tokenizer, ["assistant"]) + + self.assertIsInstance(result, dict) + self.assertIn("input_ids", result) + self.assertIn("labels", result) + self.assertIn("len", result) + self.assertEqual(len(result["input_ids"]), len(result["labels"])) + + def test_system_user_assistant_conversation(self): + """Test conversation with system, user, and assistant roles.""" + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is AI?"}, + { + "role": "assistant", + "content": "AI stands for Artificial Intelligence.", + "reasoning_content": "This is a straightforward definition question.", + }, + ] + + # Test unmasking only assistant + result = unmask_messages(messages, self.mock_tokenizer, ["assistant"]) + self.assertIsInstance(result, dict) + self.assertEqual(len(result["input_ids"]), len(result["labels"])) + + def test_multiple_unmask_roles(self): + """Test unmasking multiple roles.""" + messages = [ + {"role": "system", "content": "You are helpful."}, + { + "role": "user", + "content": "Question about math?", + "reasoning_content": "I'm asking about mathematics.", + }, + { + "role": "assistant", + "content": "Sure, I can help with math.", + "reasoning_content": "Math questions are common.", + }, + ] + + # Test unmasking both user and assistant + result = unmask_messages(messages, self.mock_tokenizer, ["user", "assistant"]) + self.assertIsInstance(result, dict) + self.assertEqual(len(result["input_ids"]), len(result["labels"])) + + def test_reasoning_only_conversation(self): + """Test conversation where all assistant messages have only reasoning_content.""" + messages = [ + {"role": "user", "content": "Think step 
by step."}, + {"role": "assistant", "reasoning_content": "Step 1: Consider the problem."}, + {"role": "user", "content": "Continue."}, + {"role": "assistant", "reasoning_content": "Step 2: Analyze the solution."}, + ] + + result = unmask_messages(messages, self.mock_tokenizer, ["assistant"]) + self.assertIsInstance(result, dict) + self.assertEqual(len(result["input_ids"]), len(result["labels"])) + + def test_empty_content_edge_cases(self): + """Test edge cases with empty content.""" + messages = [ + {"role": "user", "content": "Hello"}, + { + "role": "assistant", + "content": "", + "reasoning_content": "Empty content case", + }, + {"role": "user", "content": "Continue"}, + {"role": "assistant", "content": "Response", "reasoning_content": ""}, + ] + + result = unmask_messages(messages, self.mock_tokenizer, ["assistant"]) + self.assertIsInstance(result, dict) + self.assertEqual(len(result["input_ids"]), len(result["labels"])) + + def test_consecutive_assistant_messages(self): + """Test consecutive assistant messages (simulating the Qwen scenario).""" + messages = [ + {"role": "user", "content": "First question"}, + { + "role": "assistant", + "content": "First response A", + "reasoning_content": "Reasoning A", + }, + {"role": "user", "content": "Second question"}, + { + "role": "assistant", + "content": "Second response B", + "reasoning_content": "Reasoning B", + }, + { + "role": "assistant", + "content": "Third response C", + "reasoning_content": "Reasoning C", + }, + ] + + result = unmask_messages(messages, self.mock_tokenizer, ["assistant"]) + self.assertIsInstance(result, dict) + self.assertEqual(len(result["input_ids"]), len(result["labels"])) + + def test_long_multi_turn_conversation(self): + """Test long multi-turn conversation with various content types.""" + messages = [] + for i in range(10): + messages.append({"role": "user", "content": f"User message {i}"}) + + if i % 3 == 0: + # Content only + messages.append( + {"role": "assistant", "content": f"Assistant response {i}"} + ) + elif i % 3 == 1: + # Reasoning only + messages.append( + {"role": "assistant", "reasoning_content": f"Reasoning {i}"} + ) + else: + # Both content and reasoning + messages.append( + { + "role": "assistant", + "content": f"Response {i}", + "reasoning_content": f"Reasoning {i}", + } + ) + + result = unmask_messages(messages, self.mock_tokenizer, ["assistant"]) + self.assertIsInstance(result, dict) + self.assertEqual(len(result["input_ids"]), len(result["labels"])) + + def test_unmask_sample_function(self): + """Test the unmask_sample function with various scenarios.""" + sample_scenarios = [ + # Basic conversation + { + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + ] + }, + # With reasoning content + { + "messages": [ + {"role": "user", "content": "What is 2+2?"}, + { + "role": "assistant", + "content": "4", + "reasoning_content": "2 plus 2 equals 4", + }, + ] + }, + # With unmask flag + { + "messages": [ + {"role": "user", "content": "Question"}, + {"role": "assistant", "content": "Answer"}, + ], + "unmask": True, + }, + # Multi-turn with system + { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Question"}, + {"role": "assistant", "content": "Answer"}, + ] + }, + ] + + for i, sample in enumerate(sample_scenarios): + with self.subTest(scenario=i): + result = unmask_sample(sample, self.mock_tokenizer) + self.assertIsInstance(result, dict) + self.assertIn("input_ids", result) + self.assertIn("labels", 
result) + self.assertIn("len", result) + self.assertEqual(len(result["input_ids"]), len(result["labels"])) + + def test_wrap_masked_messages_comprehensive(self): + """Test wrap_masked_messages with comprehensive scenarios.""" + test_cases = [ + # Single role, content only + { + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + ], + "unmask_roles": ["assistant"], + "expected_wrapped_count": 1, + }, + # Single role, reasoning only + { + "messages": [ + {"role": "user", "content": "Think"}, + {"role": "assistant", "reasoning_content": "Thinking..."}, + ], + "unmask_roles": ["assistant"], + "expected_wrapped_count": 1, + }, + # Single role, both content types + { + "messages": [ + {"role": "user", "content": "Question"}, + { + "role": "assistant", + "content": "Answer", + "reasoning_content": "Thinking", + }, + ], + "unmask_roles": ["assistant"], + "expected_wrapped_count": 2, # Both content and reasoning_content wrapped + }, + # Multiple roles + { + "messages": [ + {"role": "system", "content": "System message"}, + { + "role": "user", + "content": "User question", + "reasoning_content": "User thinking", + }, + { + "role": "assistant", + "content": "Assistant answer", + "reasoning_content": "Assistant thinking", + }, + ], + "unmask_roles": ["user", "assistant"], + "expected_wrapped_count": 4, # 2 messages × 2 fields each + }, + ] + + for i, case in enumerate(test_cases): + with self.subTest(case=i): + result = wrap_masked_messages( + case["messages"], case["unmask_roles"], True + ) + + # Count wrapped fields + wrapped_count = 0 + for msg in result: + if msg["role"] in case["unmask_roles"]: + if msg.get("content") and UNMASK_BEGIN_TOKEN in msg["content"]: + wrapped_count += 1 + if ( + msg.get("reasoning_content") + and UNMASK_REASONING_BEGIN_TOKEN in msg["reasoning_content"] + ): + wrapped_count += 1 + + self.assertEqual(wrapped_count, case["expected_wrapped_count"]) + + def test_error_conditions(self): + """Test various error conditions.""" + # Test non-string content + with self.assertRaises(ValueError): + wrap_masked_messages( + [{"role": "assistant", "content": ["not", "a", "string"]}], + ["assistant"], + True, + ) + + # Test non-string reasoning_content + with self.assertRaises(ValueError): + wrap_masked_messages( + [{"role": "assistant", "reasoning_content": {"not": "a string"}}], + ["assistant"], + True, + ) + + def test_think_tag_handling(self): + """Test that and tags are properly handled.""" + # This is a basic test since the mock tokenizer handles think tags + messages = [ + {"role": "user", "content": "Question with thinking"}, + { + "role": "assistant", + "content": "Answer with more thinking", + }, + ] + + result = unmask_messages(messages, self.mock_tokenizer, ["assistant"]) + self.assertIsInstance(result, dict) + self.assertEqual(len(result["input_ids"]), len(result["labels"])) + + +class TestReasoningContentSupport(unittest.TestCase): + """Test suite for reasoning_content field support in data processing.""" + + def setUp(self): + """Set up test fixtures.""" + # Mock tokenizer for basic tests + if TRANSFORMERS_AVAILABLE: + self.mock_tokenizer = Mock(spec=PreTrainedTokenizer) + else: + self.mock_tokenizer = Mock() + self.mock_tokenizer.encode.side_effect = ( + lambda text, add_special_tokens=False: [ + hash(text) % 1000 for _ in text.split() + ] + ) + self.mock_tokenizer.decode.side_effect = lambda tokens: " ".join( + [f"token_{t}" for t in tokens] + ) + self.mock_tokenizer.apply_chat_template.side_effect = ( + 
self._mock_apply_chat_template + ) + self.mock_tokenizer.eos_token = "" + + # Set up token IDs for unmask tokens + self.unmask_begin_id = 1001 + self.unmask_end_id = 1002 + self.unmask_reasoning_begin_id = 1004 + self.unmask_reasoning_end_id = 1005 + self.eos_id = 1003 + + def mock_encode_special(text, add_special_tokens=False): + if text == UNMASK_BEGIN_TOKEN: + return [self.unmask_begin_id] + elif text == UNMASK_END_TOKEN: + return [self.unmask_end_id] + elif text == UNMASK_REASONING_BEGIN_TOKEN: + return [self.unmask_reasoning_begin_id] + elif text == UNMASK_REASONING_END_TOKEN: + return [self.unmask_reasoning_end_id] + elif text == "": + return [self.eos_id] + else: + return [hash(text) % 1000] + + self.mock_tokenizer.encode.side_effect = mock_encode_special + + def _mock_apply_chat_template( + self, + messages: t.List[Message], + tokenize: bool = True, + add_special_tokens: bool = True, + ) -> t.Union[str, t.List[int]]: + """Mock implementation of apply_chat_template.""" + template_str = "" + for msg in messages: + template_str += f"<|{msg['role']}|>\n" + if "content" in msg: + template_str += msg["content"] + if "reasoning_content" in msg: + template_str += msg["reasoning_content"] + template_str += "\n" + + if tokenize: + return [hash(template_str) % 1000 for _ in range(len(template_str.split()))] + else: + return template_str + + def test_wrap_masked_messages_with_reasoning_content(self): + """Test that wrap_masked_messages correctly wraps both content and reasoning_content.""" + messages = [ + { + "role": "user", + "content": "What is 2+2?", + }, + { + "role": "assistant", + "content": "The answer is 4.", + "reasoning_content": "I need to add 2 and 2 together. 2 + 2 = 4.", + }, + ] + + unmask_roles = ["assistant"] + result = wrap_masked_messages(messages, unmask_roles, True) + + # Check that user message is unchanged + self.assertEqual(result[0]["role"], "user") + self.assertEqual(result[0]["content"], "What is 2+2?") + self.assertNotIn("reasoning_content", result[0]) + + # Check that assistant message has both fields wrapped + self.assertEqual(result[1]["role"], "assistant") + self.assertEqual( + result[1]["content"], + f"{UNMASK_BEGIN_TOKEN}The answer is 4.{UNMASK_END_TOKEN}", + ) + self.assertEqual( + result[1]["reasoning_content"], + f"{UNMASK_REASONING_BEGIN_TOKEN}I need to add 2 and 2 together. 
2 + 2 = 4.{UNMASK_REASONING_END_TOKEN}", + ) + + def test_wrap_masked_messages_content_only(self): + """Test that wrap_masked_messages works with messages that only have content.""" + messages = [ + { + "role": "user", + "content": "Hello!", + }, + { + "role": "assistant", + "content": "Hi there!", + }, + ] + + unmask_roles = ["assistant"] + result = wrap_masked_messages(messages, unmask_roles, True) + + # Check that user message is unchanged + self.assertEqual(result[0]["role"], "user") + self.assertEqual(result[0]["content"], "Hello!") + + # Check that assistant message has content wrapped + self.assertEqual(result[1]["role"], "assistant") + self.assertEqual( + result[1]["content"], + f"{UNMASK_BEGIN_TOKEN}Hi there!{UNMASK_END_TOKEN}", + ) + self.assertNotIn("reasoning_content", result[1]) + + def test_wrap_masked_messages_reasoning_content_only(self): + """Test that wrap_masked_messages works with messages that only have reasoning_content.""" + messages = [ + { + "role": "user", + "content": "Think step by step.", + }, + { + "role": "assistant", + "reasoning_content": "Let me think about this step by step...", + }, + ] + + unmask_roles = ["assistant"] + result = wrap_masked_messages(messages, unmask_roles, True) + + # Check that user message is unchanged + self.assertEqual(result[0]["role"], "user") + self.assertEqual(result[0]["content"], "Think step by step.") + + # Check that assistant message has reasoning_content wrapped + self.assertEqual(result[1]["role"], "assistant") + self.assertEqual( + result[1]["reasoning_content"], + f"{UNMASK_REASONING_BEGIN_TOKEN}Let me think about this step by step...{UNMASK_REASONING_END_TOKEN}", + ) + self.assertNotIn("content", result[1]) + + def test_wrap_masked_messages_multiple_unmask_roles(self): + """Test that wrap_masked_messages works with multiple roles to unmask.""" + messages = [ + { + "role": "system", + "content": "You are a helpful assistant.", + }, + { + "role": "user", + "content": "What is the capital of France?", + "reasoning_content": "This is a geography question about France.", + }, + { + "role": "assistant", + "content": "The capital of France is Paris.", + "reasoning_content": "This is a straightforward geography question.", + }, + ] + + unmask_roles = ["user", "assistant"] + result = wrap_masked_messages(messages, unmask_roles, True) + + # Check that system message is unchanged + self.assertEqual(result[0]["role"], "system") + self.assertEqual(result[0]["content"], "You are a helpful assistant.") + + # Check that user message has both fields wrapped + self.assertEqual(result[1]["role"], "user") + self.assertEqual( + result[1]["content"], + f"{UNMASK_BEGIN_TOKEN}What is the capital of France?{UNMASK_END_TOKEN}", + ) + self.assertEqual( + result[1]["reasoning_content"], + f"{UNMASK_REASONING_BEGIN_TOKEN}This is a geography question about France.{UNMASK_REASONING_END_TOKEN}", + ) + + # Check that assistant message has both fields wrapped + self.assertEqual(result[2]["role"], "assistant") + self.assertEqual( + result[2]["content"], + f"{UNMASK_BEGIN_TOKEN}The capital of France is Paris.{UNMASK_END_TOKEN}", + ) + self.assertEqual( + result[2]["reasoning_content"], + f"{UNMASK_REASONING_BEGIN_TOKEN}This is a straightforward geography question.{UNMASK_REASONING_END_TOKEN}", + ) + + def test_wrap_masked_messages_non_string_content_error(self): + """Test that wrap_masked_messages raises error for non-string content.""" + messages = [ + { + "role": "assistant", + "content": ["This", "is", "not", "a", "string"], + } + ] + + unmask_roles = 
["assistant"] + + with self.assertRaises(ValueError) as context: + wrap_masked_messages(messages, unmask_roles, True) + + self.assertIn( + "unmasking non-string data types is currently unsupported", + str(context.exception), + ) + + def test_wrap_masked_messages_non_string_reasoning_content_error(self): + """Test that wrap_masked_messages raises error for non-string reasoning_content.""" + messages = [ + { + "role": "assistant", + "content": "Valid content", + "reasoning_content": {"thinking": "This is not a string"}, + } + ] + + unmask_roles = ["assistant"] + + with self.assertRaises(ValueError) as context: + wrap_masked_messages(messages, unmask_roles, True) + + self.assertIn( + "received an entry for `reasoning_content` which was not a string", + str(context.exception), + ) + + def test_unmask_messages_with_reasoning_content(self): + """Test that unmask_messages correctly processes reasoning_content.""" + # This is a complex integration test, so we'll test it with a real tokenizer in the integration tests + # For unit testing, we just verify that the wrap_masked_messages function properly handles reasoning_content + messages = [ + { + "role": "user", + "content": "What is 5*7?", + }, + { + "role": "assistant", + "content": "35", + "reasoning_content": "5 times 7 equals 35", + }, + ] + + unmask_roles = ["assistant"] + + # Test that wrap_masked_messages works correctly with reasoning_content + wrapped = wrap_masked_messages(messages, unmask_roles, True) + + # Verify that both content and reasoning_content are wrapped + self.assertIn(UNMASK_BEGIN_TOKEN, wrapped[1]["content"]) + self.assertIn(UNMASK_END_TOKEN, wrapped[1]["content"]) + self.assertIn(UNMASK_REASONING_BEGIN_TOKEN, wrapped[1]["reasoning_content"]) + self.assertIn(UNMASK_REASONING_END_TOKEN, wrapped[1]["reasoning_content"]) + + # Verify the user message is unchanged + self.assertEqual(wrapped[0]["content"], "What is 5*7?") + self.assertNotIn("reasoning_content", wrapped[0]) + + def test_unmask_sample_with_reasoning_content(self): + """Test that unmask_sample correctly processes samples with reasoning_content.""" + sample = { + "messages": [ + { + "role": "user", + "content": "Explain photosynthesis.", + }, + { + "role": "assistant", + "content": "Photosynthesis is the process by which plants make food.", + "reasoning_content": "I need to explain photosynthesis in simple terms.", + }, + ] + } + + result = unmask_sample(sample, self.mock_tokenizer) + + # Check that result has the expected structure + self.assertIsInstance(result, dict) + self.assertIn("input_ids", result) + self.assertIn("labels", result) + self.assertIn("len", result) + + def test_unmask_sample_with_unmask_flag(self): + """Test that unmask_sample correctly handles the unmask flag.""" + sample = { + "messages": [ + { + "role": "user", + "content": "Hello", + }, + { + "role": "assistant", + "content": "Hi", + "reasoning_content": "Simple greeting", + }, + ], + "unmask": True, + } + + result = unmask_sample(sample, self.mock_tokenizer) + + # Check that result has the expected structure + self.assertIsInstance(result, dict) + self.assertIn("input_ids", result) + self.assertIn("labels", result) + self.assertIn("len", result) + + +@unittest.skipUnless(TRANSFORMERS_AVAILABLE, "transformers library not available") +class TestReasoningContentWithRealTokenizers(unittest.TestCase): + """Test reasoning_content functionality with real tokenizers.""" + + @unittest.skipUnless(PYTEST_AVAILABLE, "pytest not available") + def test_with_qwen_tokenizer(self): + """Test 
reasoning_content functionality with Qwen3-32B tokenizer."""
+        try:
+            # Load the Qwen3-32B tokenizer; skip the test if it cannot be fetched
+            tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-32B")
+        except Exception as e:
+            self.skipTest(f"Qwen tokenizer not available: {e}")
+
+        # Add the unmask tokens to the tokenizer
+        tokenizer.add_special_tokens(
+            {
+                "additional_special_tokens": [
+                    UNMASK_BEGIN_TOKEN,
+                    UNMASK_END_TOKEN,
+                    UNMASK_REASONING_BEGIN_TOKEN,
+                    UNMASK_REASONING_END_TOKEN,
+                    MASK_TOKEN,
+                ]
+            }
+        )
+
+        messages = [
+            {
+                "role": "user",
+                "content": "What is 2+2?",
+            },
+            {
+                "role": "assistant",
+                "content": "4",
+                "reasoning_content": "I need to add 2 and 2, which equals 4.",
+            },
+        ]
+
+        # Test wrap_masked_messages
+        wrapped = wrap_masked_messages(messages, ["assistant"], True)
+
+        # Verify that both content and reasoning_content are wrapped
+        self.assertIn(UNMASK_BEGIN_TOKEN, wrapped[1]["content"])
+        self.assertIn(UNMASK_END_TOKEN, wrapped[1]["content"])
+        self.assertIn(UNMASK_REASONING_BEGIN_TOKEN, wrapped[1]["reasoning_content"])
+        self.assertIn(UNMASK_REASONING_END_TOKEN, wrapped[1]["reasoning_content"])
+
+        # Test unmask_messages
+        result = unmask_messages(messages, tokenizer, ["assistant"])
+
+        # Verify the result structure
+        self.assertIsInstance(result, dict)
+        self.assertIn("input_ids", result)
+        self.assertIn("labels", result)
+        self.assertIn("len", result)
+
+        # Verify that special tokens are not in the final output
+        unmask_begin_id = tokenizer.encode(
+            UNMASK_BEGIN_TOKEN, add_special_tokens=False
+        )[0]
+        unmask_end_id = tokenizer.encode(UNMASK_END_TOKEN, add_special_tokens=False)[0]
+
+        self.assertNotIn(unmask_begin_id, result["input_ids"])
+        self.assertNotIn(unmask_end_id, result["input_ids"])
+
+    @unittest.skipUnless(PYTEST_AVAILABLE, "pytest not available")
+    def test_with_mistral_tokenizer(self):
+        """Test reasoning_content functionality with Mistral tokenizer."""
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(
+                "mistralai/Mistral-7B-Instruct-v0.1"
+            )
+        except Exception as e:
+            self.skipTest(f"Mistral tokenizer not available: {e}")
+
+        # Add the unmask tokens to the tokenizer
+        tokenizer.add_special_tokens(
+            {
+                "additional_special_tokens": [
+                    UNMASK_BEGIN_TOKEN,
+                    UNMASK_END_TOKEN,
+                    UNMASK_REASONING_BEGIN_TOKEN,
+                    UNMASK_REASONING_END_TOKEN,
+                    MASK_TOKEN,
+                ]
+            }
+        )
+
+        messages = [
+            {
+                "role": "user",
+                "content": "Calculate 5*6",
+            },
+            {
+                "role": "assistant",
+                "content": "30",
+                "reasoning_content": "5 multiplied by 6 equals 30.",
+            },
+        ]
+
+        # Test the full pipeline
+        result = unmask_messages(messages, tokenizer, ["assistant"])
+
+        # Verify the result structure and content
+        self.assertIsInstance(result, dict)
+        self.assertIn("input_ids", result)
+        self.assertIn("labels", result)
+        self.assertIn("len", result)
+        self.assertGreater(len(result["input_ids"]), 0)
+        self.assertEqual(len(result["input_ids"]), len(result["labels"]))
+
+    def test_edge_cases_with_reasoning_content(self):
+        """Test edge cases for reasoning_content functionality."""
+        # Test empty reasoning_content
+        messages = [
+            {
+                "role": "assistant",
+                "content": "Response",
+                "reasoning_content": "",
+            }
+        ]
+
+        wrapped = wrap_masked_messages(messages, ["assistant"], True)
+        self.assertEqual(
+            wrapped[0]["reasoning_content"],
+            f"{UNMASK_REASONING_BEGIN_TOKEN}{UNMASK_REASONING_END_TOKEN}",
+        )
+
+        # Test only reasoning_content without content
+        messages = [
+            {
+                "role": "assistant",
+                "reasoning_content": "Thinking process",
+            }
+        ]
+
+        wrapped = wrap_masked_messages(messages, ["assistant"], True)
+        self.assertNotIn("content", wrapped[0])
+        self.assertEqual(
+            wrapped[0]["reasoning_content"],
+            f"{UNMASK_REASONING_BEGIN_TOKEN}Thinking process{UNMASK_REASONING_END_TOKEN}",
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/unit/test_reasoning_unmask.py b/tests/unit/test_reasoning_unmask.py
new file mode 100644
index 00000000..848fd5f6
--- /dev/null
+++ b/tests/unit/test_reasoning_unmask.py
@@ -0,0 +1,520 @@
+#!/usr/bin/env python3
+"""
+Comprehensive tests for reasoning content unmasking functionality.
+Tests the merging of reasoning and content unmask regions as described
+in the InstructLab training documentation.
+"""
+
+# Standard
+from unittest.mock import Mock
+
+# Third Party
+from transformers import AutoTokenizer
+import pytest
+
+# First Party
+from instructlab.training.data_process import (
+    UNMASK_BEGIN_TOKEN,
+    UNMASK_END_TOKEN,
+    UNMASK_REASONING_BEGIN_TOKEN,
+    UNMASK_REASONING_END_TOKEN,
+    unmask_messages,
+    wrap_masked_messages,
+)
+
+
+class TestReasoningContentUnmasking:
+    """Test suite for reasoning content unmasking with region merging."""
+
+    @pytest.fixture
+    def mock_tokenizer(self):
+        """Create a mock tokenizer with all necessary tokens."""
+        tokenizer = Mock()
+
+        # Mock token encoding
+        def mock_encode(text, add_special_tokens=False):
+            token_map = {
+                UNMASK_BEGIN_TOKEN: [1000],
+                UNMASK_END_TOKEN: [1001],
+                UNMASK_REASONING_BEGIN_TOKEN: [1002],
+                UNMASK_REASONING_END_TOKEN: [1003],
+                "<|endoftext|>": [50256],
+                "<|im_end|>": [50257],
+                "<|im_start|>": [50258],
+                "<think>": [50259],
+                "</think>": [50260],
+                "\n\n": [50261],
+            }
+            # For unknown text, return a hash-based token
+            return token_map.get(text, [hash(text) % 10000 + 100])
+
+        tokenizer.encode = mock_encode
+        tokenizer.eos_token = "<|im_end|>"
+
+        return tokenizer
+
+    def test_reasoning_content_merging_basic(self, mock_tokenizer):
+        """Test basic merging of reasoning and content regions."""
+        messages = [
+            {"role": "user", "content": "Where is Paris?"},
+            {
+                "role": "assistant",
+                "content": "Paris is in Europe",
+                "reasoning_content": "Paris is the capital of France, France is in Europe",
+            },
+        ]
+
+        # Simulate Qwen/DeepSeek style template output:
+        # <|im_start|>user\nWhere is Paris?<|im_end|>
+        # <|im_start|>assistant\n<think>\n[REASONING]\n</think>\n\n[CONTENT]<|im_end|>
+        mock_tokenizer.apply_chat_template.return_value = [
+            50258,
+            100,
+            101,
+            50257, # user message
+            50258,
+            200, # <|im_start|>assistant
+            50259, # <think>
+            1002,
+            300,
+            301,
+            1003, # wrapped reasoning content
+            50260, # </think>
+            50261, # \n\n
+            1000,
+            400,
+            401,
+            402,
+            1001, # wrapped content
+            50257, # <|im_end|>
+        ]
+
+        result = unmask_messages(messages, mock_tokenizer, ["assistant"])
+
+        # The unmask regions should be merged, unmasking everything from reasoning to content
+        # including the </think> and \n\n tokens
+        expected_unmasked_tokens = [300, 301, 50260, 50261, 400, 401, 402, 50257]
+
+        # Verify that all expected tokens are unmasked
+        for tok in expected_unmasked_tokens:
+            if tok in result["input_ids"]:
+                # For tokens that appear multiple times, check if at least one is unmasked
+                indices = [i for i, t in enumerate(result["input_ids"]) if t == tok]
+                assert any(result["labels"][i] == tok for i in indices), (
+                    f"Token {tok} not unmasked"
+                )
+
+    def test_reasoning_content_no_merging_when_different_messages(self, mock_tokenizer):
+        """Test that regions are not merged when they belong to different messages."""
+        messages = [
+            {
+                "role": "assistant",
+                "content": "Answer1",
+                "reasoning_content": "Thinking1",
+            },
+            {
+                "role": "assistant",
+                "content": "Answer2",
+                "reasoning_content": "Thinking2",
+            },
+        ]
+
+        # Simulate regions from two different assistant messages
+        mock_tokenizer.apply_chat_template.return_value = [
+            # First message
+            1002,
+            100,
+            1003, # reasoning region for message 1
+            1000,
+            200,
+            1001, # content region for message 1
+            50257, # EOS
+            # Second message
+            1002,
+            300,
+            1003, # reasoning region for message 2
+            1000,
+            400,
+            1001, # content region for message 2
+            50257, # EOS
+        ]
+
+        result = unmask_messages(messages, mock_tokenizer, ["assistant"])
+
+        # All regions should be unmasked, but first message's regions should be separate from second's
+        assert 100 in result["input_ids"] # reasoning 1
+        assert 200 in result["input_ids"] # content 1
+        assert 300 in result["input_ids"] # reasoning 2
+        assert 400 in result["input_ids"] # content 2
+
+        # Both messages' content should be unmasked
+        for tok in [100, 200, 300, 400]:
+            idx = result["input_ids"].index(tok)
+            assert result["labels"][idx] == tok
+
+    def test_multiple_assistant_messages_with_reasoning(self, mock_tokenizer):
+        """Test handling of multiple assistant messages with reasoning content."""
+        messages = [
+            {"role": "user", "content": "Question 1"},
+            {
+                "role": "assistant",
+                "content": "Answer 1",
+                "reasoning_content": "Thinking 1",
+            },
+            {"role": "user", "content": "Question 2"},
+            {
+                "role": "assistant",
+                "content": "Answer 2",
+                "reasoning_content": "Thinking 2",
+            },
+        ]
+
+        # Simulate chat template output with two assistant responses
+        mock_tokenizer.apply_chat_template.return_value = [
+            # First exchange
+            50258,
+            100,
+            50257, # user 1
+            50258,
+            200, # assistant start
+            50259, # <think>
+            1002,
+            300,
+            1003, # reasoning 1
+            50260, # </think>
+            50261, # \n\n
+            1000,
+            400,
+            1001, # content 1
+            50257, # <|im_end|>
+            # Second exchange
+            50258,
+            500,
+            50257, # user 2
+            50258,
+            600, # assistant start
+            50259, # <think>
+            1002,
+            700,
+            1003, # reasoning 2
+            50260, # </think>
+            50261, # \n\n
+            1000,
+            800,
+            1001, # content 2
+            50257, # <|im_end|>
+        ]
+
+        result = unmask_messages(messages, mock_tokenizer, ["assistant"])
+
+        # Both assistant messages should have their reasoning and content unmasked
+        # Check first assistant response
+        assert 300 in [
+            t for i, t in enumerate(result["input_ids"]) if result["labels"][i] != -100
+        ]
+        assert 400 in [
+            t for i, t in enumerate(result["input_ids"]) if result["labels"][i] != -100
+        ]
+
+        # Check second assistant response
+        assert 700 in [
+            t for i, t in enumerate(result["input_ids"]) if result["labels"][i] != -100
+        ]
+        assert 800 in [
+            t for i, t in enumerate(result["input_ids"]) if result["labels"][i] != -100
+        ]
+
+    def test_reasoning_without_content(self, mock_tokenizer):
+        """Test messages that only have reasoning_content without regular content."""
+        messages = [
+            {"role": "user", "content": "Think about this"},
+            {"role": "assistant", "reasoning_content": "Let me think..."},
+        ]
+
+        mock_tokenizer.apply_chat_template.return_value = [
+            50258,
+            100,
+            50257, # user
+            50258,
+            200, # assistant start
+            50259, # <think>
+            1002,
+            300,
+            301,
+            1003, # reasoning
+            50260, # </think>
+            50257, # <|im_end|>
+        ]
+
+        result = unmask_messages(messages, mock_tokenizer, ["assistant"])
+
+        # Reasoning content and EOS should be unmasked
+        assert 300 in result["input_ids"]
+        assert 301 in result["input_ids"]
+        assert 50257 in result["input_ids"] # EOS token
+
+        # Check labels
+        idx_300 = result["input_ids"].index(300)
+        idx_301 = result["input_ids"].index(301)
+        # Find the last 
EOS token (assistant's) + eos_indices = [i for i, t in enumerate(result["input_ids"]) if t == 50257] + idx_eos = eos_indices[-1] if eos_indices else None + + assert result["labels"][idx_300] == 300 + assert result["labels"][idx_301] == 301 + if idx_eos is not None: + assert result["labels"][idx_eos] == 50257 + + def test_content_without_reasoning(self, mock_tokenizer): + """Test messages that only have content without reasoning_content.""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + mock_tokenizer.apply_chat_template.return_value = [ + 50258, + 100, + 50257, # user + 50258, + 200, # assistant start + 1000, + 300, + 301, + 302, + 1001, # content + 50257, # <|im_end|> + ] + + result = unmask_messages(messages, mock_tokenizer, ["assistant"]) + + # Content and EOS should be unmasked + assert all(tok in result["input_ids"] for tok in [300, 301, 302, 50257]) + + # Verify unmasking + for tok in [300, 301, 302]: + idx = result["input_ids"].index(tok) + assert result["labels"][idx] == tok + + # For EOS token, check the last occurrence (assistant's EOS) + eos_indices = [i for i, t in enumerate(result["input_ids"]) if t == 50257] + assert len(eos_indices) >= 1 + last_eos_idx = eos_indices[-1] + assert result["labels"][last_eos_idx] == 50257 + + def test_reasoning_content_order_variations(self, mock_tokenizer): + """Test different orderings of reasoning and content regions.""" + messages = [ + {"role": "assistant", "content": "Answer", "reasoning_content": "Reasoning"} + ] + + # Test content before reasoning (unusual but possible) + mock_tokenizer.apply_chat_template.return_value = [ + 1000, + 100, + 1001, # content first + 50261, # separator + 1002, + 200, + 1003, # reasoning after + 50257, # EOS + ] + + result = unmask_messages(messages, mock_tokenizer, ["assistant"]) + + # Both regions should be merged and unmasked + assert 100 in result["input_ids"] + assert 200 in result["input_ids"] + assert 50257 in result["input_ids"] + + def test_unmask_all_roles_with_reasoning(self, mock_tokenizer): + """Test unmasking all roles when some have reasoning content.""" + messages = [ + {"role": "system", "content": "You are helpful"}, + { + "role": "user", + "content": "Question", + "reasoning_content": "User thinking", + }, + { + "role": "assistant", + "content": "Answer", + "reasoning_content": "Assistant thinking", + }, + ] + + mock_tokenizer.apply_chat_template.return_value = [ + # System (no reasoning) + 1000, + 50, + 1001, + # User with reasoning + 1002, + 100, + 1003, + 1000, + 150, + 1001, + # Assistant with reasoning + 1002, + 200, + 1003, + 1000, + 250, + 1001, + 50257, + ] + + result = unmask_messages( + messages, mock_tokenizer, ["system", "user", "assistant"] + ) + + # All content should be unmasked + unmasked_tokens = [50, 100, 150, 200, 250, 50257] + for tok in unmasked_tokens: + idx = result["input_ids"].index(tok) + assert result["labels"][idx] == tok + + def test_edge_case_empty_reasoning_content(self, mock_tokenizer): + """Test handling of empty reasoning_content field.""" + messages = [{"role": "assistant", "content": "Answer", "reasoning_content": ""}] + + # Empty reasoning should still create unmask tokens but with no content between + mock_tokenizer.apply_chat_template.return_value = [ + 1002, + 1003, # empty reasoning + 1000, + 100, + 1001, # content + 50257, + ] + + result = unmask_messages(messages, mock_tokenizer, ["assistant"]) + + # Content and EOS should be unmasked + assert 100 in result["input_ids"] + assert 50257 
in result["input_ids"]
+
+    def test_complex_qwen_deepseek_scenario(self, mock_tokenizer):
+        """Test a complex scenario mimicking Qwen/DeepSeek behavior."""
+        messages = [
+            {"role": "user", "content": "What is 2+2?"},
+            {
+                "role": "assistant",
+                "content": "The answer is 4.",
+                "reasoning_content": "I need to add 2 and 2. 2 + 2 = 4.",
+            },
+            {
+                "role": "assistant",
+                "content": "Let me elaborate: it's basic arithmetic.",
+                "reasoning_content": "",
+            },
+        ]
+
+        # Simulate the complex template behavior described in the documentation
+        mock_tokenizer.apply_chat_template.return_value = [
+            50258,
+            100,
+            50257, # user
+            50258,
+            200, # first assistant
+            50259, # <think>
+            1002,
+            300,
+            301,
+            302,
+            1003, # reasoning
+            50260, # </think>
+            50261, # \n\n
+            1000,
+            400,
+            401,
+            1001, # content
+            50257, # <|im_end|>
+            50258,
+            500, # second assistant (continuation)
+            50259, # <think> (empty)
+            1002,
+            1003, # empty reasoning
+            50260, # </think>
+            50261, # \n\n
+            1000,
+            600,
+            601,
+            602,
+            1001, # content
+            50257, # <|im_end|>
+        ]
+
+        result = unmask_messages(messages, mock_tokenizer, ["assistant"])
+
+        # Both assistant messages should be properly unmasked
+        # First assistant: reasoning + content
+        for tok in [300, 301, 302, 400, 401]:
+            assert tok in result["input_ids"]
+            idx = result["input_ids"].index(tok)
+            assert result["labels"][idx] == tok
+
+        # Second assistant: content only (empty reasoning)
+        for tok in [600, 601, 602]:
+            assert tok in result["input_ids"]
+            idx = result["input_ids"].index(tok)
+            assert result["labels"][idx] == tok
+
+    def test_validation_nested_reasoning_tokens(self, mock_tokenizer):
+        """Test that nested reasoning tokens raise appropriate errors."""
+        messages = [{"role": "assistant", "content": "Test"}]
+
+        # Nested reasoning begin tokens
+        mock_tokenizer.apply_chat_template.return_value = [
+            1002,
+            100,
+            1002,
+            200,
+            1003,
+            1003,
+        ]
+
+        with pytest.raises(
+            ValueError, match="encountered.*UNMASK_REASONING.*while already unmasking"
+        ):
+            unmask_messages(messages, mock_tokenizer, ["assistant"])
+
+    def test_known_token_ids_preserved(self, mock_tokenizer):
+        """Test that specific known token IDs are handled correctly."""
+        messages = [
+            {
+                "role": "assistant",
+                "content": "Final answer",
+                "reasoning_content": "Thinking process",
+            }
+        ]
+
+        # Use specific token IDs to verify preservation
+        mock_tokenizer.apply_chat_template.return_value = [
+            50258, # <|im_start|>
+            1002,
+            12345,
+            1003, # reasoning with specific ID
+            50260, # </think>
+            1000,
+            67890,
+            1001, # content with specific ID
+            50257, # <|im_end|>
+        ]
+
+        result = unmask_messages(messages, mock_tokenizer, ["assistant"])
+
+        # Verify specific tokens are preserved
+        assert 12345 in result["input_ids"]
+        assert 67890 in result["input_ids"]
+
+        # Verify they're unmasked
+        idx_12345 = result["input_ids"].index(12345)
+        idx_67890 = result["input_ids"].index(67890)
+        assert result["labels"][idx_12345] == 12345
+        assert result["labels"][idx_67890] == 67890
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/tests/unit/test_unmask_messages.py b/tests/unit/test_unmask_messages.py
new file mode 100644
index 00000000..700fe0c0
--- /dev/null
+++ b/tests/unit/test_unmask_messages.py
@@ -0,0 +1,998 @@
+#!/usr/bin/env python3
+"""
+Comprehensive tests for the unmask_messages function to ensure behavior consistency
+across different tokenizers and scenarios.
+""" + +# Standard +from unittest.mock import Mock, patch +import os +import tempfile + +# Third Party +from transformers import AutoTokenizer +import pytest + +# First Party +# Import the functions we want to test +from instructlab.training.data_process import ( + UNMASK_BEGIN_TOKEN, + UNMASK_END_TOKEN, + UNMASK_REASONING_BEGIN_TOKEN, + UNMASK_REASONING_END_TOKEN, + unmask_messages, + unmask_sample, + wrap_masked_messages, +) +from instructlab.training.type_definitions import Message + + +class TestWrapMaskedMessages: + """Test suite for wrap_masked_messages functionality.""" + + @pytest.fixture + def sample_messages(self): + """Sample messages for testing.""" + return [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well, thank you!"}, + ] + + @pytest.fixture + def reasoning_messages(self): + """Sample messages with reasoning content.""" + return [ + {"role": "user", "content": "What is 2+2?"}, + { + "role": "assistant", + "content": "The answer is 4.", + "reasoning_content": "I need to add 2 and 2 together. 2 + 2 = 4.", + }, + ] + + def test_wrap_masked_messages_basic(self, sample_messages): + """Test basic message wrapping functionality.""" + wrapped = wrap_masked_messages(sample_messages, ["assistant"]) + + # Check that only assistant messages are wrapped + assert ( + sample_messages[0]["content"] == wrapped[0]["content"] + ) # system unchanged + assert sample_messages[1]["content"] == wrapped[1]["content"] # user unchanged + assert ( + wrapped[2]["content"] + == f"{UNMASK_BEGIN_TOKEN}I'm doing well, thank you!{UNMASK_END_TOKEN}" + ) + + def test_wrap_masked_messages_with_reasoning(self, reasoning_messages): + """Test message wrapping with reasoning content.""" + # Test with reasoning content disabled (default behavior) + wrapped = wrap_masked_messages(reasoning_messages, ["assistant"]) + + # Check content wrapping + expected_content = f"{UNMASK_BEGIN_TOKEN}The answer is 4.{UNMASK_END_TOKEN}" + assert wrapped[1]["content"] == expected_content + + # Check reasoning content is NOT processed when disabled + assert ( + wrapped[1]["reasoning_content"] + == "I need to add 2 and 2 together. 2 + 2 = 4." + ) + + # Test with reasoning content enabled + wrapped_with_reasoning = wrap_masked_messages( + reasoning_messages, ["assistant"], enable_reasoning_content=True + ) + + # Check content is wrapped with regular tokens and reasoning with reasoning-specific tokens + assert wrapped_with_reasoning[1]["content"] == expected_content + expected_reasoning = f"{UNMASK_REASONING_BEGIN_TOKEN}I need to add 2 and 2 together. 
2 + 2 = 4.{UNMASK_REASONING_END_TOKEN}" + assert wrapped_with_reasoning[1]["reasoning_content"] == expected_reasoning + + def test_wrap_masked_messages_multiple_roles(self, sample_messages): + """Test wrapping messages for multiple roles.""" + wrapped = wrap_masked_messages(sample_messages, ["user", "assistant"]) + + # Both user and assistant should be wrapped + assert ( + wrapped[1]["content"] + == f"{UNMASK_BEGIN_TOKEN}Hello, how are you?{UNMASK_END_TOKEN}" + ) + assert ( + wrapped[2]["content"] + == f"{UNMASK_BEGIN_TOKEN}I'm doing well, thank you!{UNMASK_END_TOKEN}" + ) + # System should remain unchanged + assert wrapped[0]["content"] == sample_messages[0]["content"] + + def test_wrap_masked_messages_error_on_non_string_content(self): + """Test that wrapping non-string content raises an error.""" + messages = [{"role": "assistant", "content": ["not", "a", "string"]}] + + with pytest.raises(ValueError, match="unmasking non-string data types"): + wrap_masked_messages(messages, ["assistant"]) + + def test_wrap_masked_messages_empty_roles_list(self, sample_messages): + """Test wrapping with empty roles list.""" + wrapped = wrap_masked_messages(sample_messages, []) + + # All messages should remain unchanged + for i, msg in enumerate(sample_messages): + assert wrapped[i]["content"] == msg["content"] + + def test_wrap_masked_messages_preserves_other_fields(self): + """Test that other message fields are preserved during wrapping.""" + messages = [ + { + "role": "assistant", + "content": "Hello", + "custom_field": "custom_value", + "another_field": 123, + } + ] + wrapped = wrap_masked_messages(messages, ["assistant"]) + + assert wrapped[0]["custom_field"] == "custom_value" + assert wrapped[0]["another_field"] == 123 + assert wrapped[0]["role"] == "assistant" + + +class TestUnmaskMessages: + """Test suite for unmask_messages functionality.""" + + @pytest.fixture + def mock_tokenizer(self): + """Create a mock tokenizer for basic testing.""" + tokenizer = Mock() + + # Mock the special token encodings + def mock_encode(text, add_special_tokens=False): + token_map = { + UNMASK_BEGIN_TOKEN: [1000], + UNMASK_END_TOKEN: [1001], + UNMASK_REASONING_BEGIN_TOKEN: [1002], + UNMASK_REASONING_END_TOKEN: [1003], + "<|endoftext|>": [0], + } + return token_map.get(text, [hash(text) % 500 + 100]) + + tokenizer.encode = mock_encode + tokenizer.eos_token = "<|endoftext|>" + + return tokenizer + + @pytest.fixture + def simple_input_ids(self): + """Simple token sequence for testing: user msg + assistant msg.""" + # Represents: [role_tokens] user_content [unmask_begin] assistant_content [unmask_end] [eos] + return [50, 51, 52, 1000, 200, 201, 202, 1001, 0] + + def test_unmask_messages_basic_flow(self, mock_tokenizer): + """Test the basic flow of unmask_messages.""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + # Mock apply_chat_template to return our simple input sequence + mock_tokenizer.apply_chat_template.return_value = [50, 1000, 200, 201, 1001, 0] + + result = unmask_messages(messages, mock_tokenizer, ["assistant"]) + + # Verify the result structure + assert "input_ids" in result + assert "labels" in result + assert "len" in result + assert len(result["input_ids"]) == len(result["labels"]) + + # Should not contain unmask tokens in final output + assert 1000 not in result["input_ids"] # UNMASK_BEGIN_TOKEN + assert 1001 not in result["input_ids"] # UNMASK_END_TOKEN + + def test_unmask_messages_assistant_only_unmasking(self, mock_tokenizer): + """Test 
that only assistant tokens are unmasked when specified.""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + # Sequence: user_tokens + unmask_begin + assistant_tokens + unmask_end + mock_tokenizer.apply_chat_template.return_value = [50, 51, 1000, 200, 201, 1001] + + result = unmask_messages(messages, mock_tokenizer, ["assistant"]) + + # User tokens should be masked (-100), assistant tokens should be unmasked + expected_input_ids = [50, 51, 200, 201] + expected_labels = [-100, -100, 200, 201] + + assert result["input_ids"] == expected_input_ids + assert result["labels"] == expected_labels + + def test_unmask_messages_multiple_roles(self, mock_tokenizer): + """Test unmasking multiple roles.""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + # Both user and assistant wrapped with unmask tokens + mock_tokenizer.apply_chat_template.return_value = [ + 1000, + 50, + 51, + 1001, # user wrapped + 1000, + 200, + 201, + 1001, # assistant wrapped + ] + + result = unmask_messages(messages, mock_tokenizer, ["user", "assistant"]) + + # Both should be unmasked + expected_input_ids = [50, 51, 200, 201] + expected_labels = [50, 51, 200, 201] + + assert result["input_ids"] == expected_input_ids + assert result["labels"] == expected_labels + + def test_unmask_messages_with_eos_token_for_assistant(self, mock_tokenizer): + """Test that EOS token is unmasked for assistant role.""" + messages = [{"role": "assistant", "content": "Hello"}] + + # Assistant content followed by EOS token + mock_tokenizer.apply_chat_template.return_value = [1000, 200, 201, 1001, 0] + + result = unmask_messages(messages, mock_tokenizer, ["assistant"]) + + # Both assistant content and EOS should be unmasked + expected_input_ids = [200, 201, 0] + expected_labels = [200, 201, 0] + + assert result["input_ids"] == expected_input_ids + assert result["labels"] == expected_labels + + def test_unmask_messages_no_eos_token(self, mock_tokenizer): + """Test behavior when tokenizer has no EOS token.""" + mock_tokenizer.eos_token = None + messages = [{"role": "assistant", "content": "Hello"}] + + mock_tokenizer.apply_chat_template.return_value = [1000, 200, 201, 1001] + + result = unmask_messages(messages, mock_tokenizer, ["assistant"]) + + # Should work normally without EOS handling + expected_input_ids = [200, 201] + expected_labels = [200, 201] + + assert result["input_ids"] == expected_input_ids + assert result["labels"] == expected_labels + + def test_unmask_messages_validation_errors(self, mock_tokenizer): + """Test that validation errors are properly raised.""" + messages = [{"role": "assistant", "content": "Hello"}] + + # Simulate a bug where unmask tokens remain in output by mocking a faulty sequence + mock_tokenizer.apply_chat_template.return_value = [ + 1000, + 200, + 1000, + 1001, + ] # nested begin token + + with pytest.raises(ValueError, match="encountered.*while already unmasking"): + unmask_messages(messages, mock_tokenizer, ["assistant"]) + + def test_unmask_messages_mismatched_end_token(self, mock_tokenizer): + """Test error when encountering end token while not unmasking.""" + messages = [{"role": "user", "content": "Hello"}] + + # End token without begin token + mock_tokenizer.apply_chat_template.return_value = [200, 1001] + + with pytest.raises(ValueError, match="encountered.*while not unmasking"): + unmask_messages(messages, mock_tokenizer, ["assistant"]) + + def test_unmask_messages_empty_input(self, 
mock_tokenizer): + """Test behavior with empty input.""" + mock_tokenizer.apply_chat_template.return_value = [] + + result = unmask_messages([], mock_tokenizer, ["assistant"]) + + assert result["input_ids"] == [] + assert result["labels"] == [] + assert result["len"] == 0 + + def test_unmask_messages_reasoning_content_handling(self, mock_tokenizer): + """Test that reasoning content is properly handled.""" + messages = [ + {"role": "assistant", "content": "Answer", "reasoning_content": "Thinking"} + ] + + # When messages have reasoning content, wrap_masked_messages uses reasoning-specific tokens + mock_tokenizer.apply_chat_template.return_value = [ + 1002, # UNMASK_REASONING_BEGIN + 200, # reasoning content + 1003, # UNMASK_REASONING_END + 1000, # UNMASK_BEGIN + 100, # content + 1001, # UNMASK_END + ] + + # The new implementation correctly handles reasoning content + result = unmask_messages(messages, mock_tokenizer, ["assistant"]) + + # Both reasoning and content should be unmasked + assert result["input_ids"] == [200, 100] + assert result["labels"] == [200, 100] + assert result["len"] == 2 + + +class TestWithRealTokenizers: + """Test with actual tokenizer implementations to ensure realistic behavior.""" + + @pytest.fixture(scope="class") + def test_tokenizer(self): + """Get a small test tokenizer.""" + try: + # Use a small model for testing + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-gpt2") + + # Add the special tokens + tokenizer.add_special_tokens( + { + "additional_special_tokens": [ + UNMASK_BEGIN_TOKEN, + UNMASK_END_TOKEN, + UNMASK_REASONING_BEGIN_TOKEN, + UNMASK_REASONING_END_TOKEN, + ] + } + ) + + # Set a simple chat template for testing + tokenizer.chat_template = "{% for message in messages %}{{ message['role'] }}: {{ message['content'] }}{% if message.get('reasoning_content') %} [REASONING: {{ message['reasoning_content'] }}]{% endif %}\n{% endfor %}" + + return tokenizer + except Exception as e: + pytest.skip(f"Could not load test tokenizer: {e}") + + def test_real_tokenizer_basic_functionality(self, test_tokenizer): + """Test basic functionality with a real tokenizer.""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + result = unmask_messages(messages, test_tokenizer, ["assistant"]) + + # Basic sanity checks + assert len(result["input_ids"]) > 0 + assert len(result["labels"]) == len(result["input_ids"]) + assert result["len"] == len(result["input_ids"]) + + # Check that some tokens are masked (-100) and some are not + masked_count = sum(1 for label in result["labels"] if label == -100) + unmasked_count = len(result["labels"]) - masked_count + + assert masked_count > 0, "Should have some masked tokens" + assert unmasked_count > 0, "Should have some unmasked tokens" + + def test_real_tokenizer_with_reasoning(self, test_tokenizer): + """Test reasoning content with a real tokenizer.""" + messages = [ + {"role": "user", "content": "What is 2+2?"}, + { + "role": "assistant", + "content": "The answer is 4.", + "reasoning_content": "I need to calculate 2+2.", + }, + ] + + result = unmask_messages(messages, test_tokenizer, ["assistant"]) + + # Should have processed both content and reasoning_content + assert len(result["input_ids"]) > 5 # Should be reasonably long + assert result["len"] == len(result["input_ids"]) + + # Should have both masked and unmasked tokens + masked_count = sum(1 for label in result["labels"] if label == -100) + unmasked_count = len(result["labels"]) - masked_count + + 
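+        # Reminder: -100 is the ignore index used by PyTorch's cross-entropy loss,
+        # so masked positions contribute nothing to the training loss.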
assert masked_count > 0, "Should have some masked tokens (user message)"
+        assert unmasked_count > 0, (
+            "Should have some unmasked tokens (assistant content + reasoning)"
+        )
+
+    def test_real_tokenizer_edge_cases(self, test_tokenizer):
+        """Test edge cases with real tokenizer."""
+        # Empty assistant message
+        messages = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": ""},
+        ]
+
+        result = unmask_messages(messages, test_tokenizer, ["assistant"])
+        assert len(result["input_ids"]) > 0
+
+        # Only reasoning content, no regular content
+        messages = [
+            {"role": "user", "content": "Think about this"},
+            {"role": "assistant", "reasoning_content": "Let me think..."},
+        ]
+
+        result = unmask_messages(messages, test_tokenizer, ["assistant"])
+        assert len(result["input_ids"]) > 0
+
+        # Should have some unmasked tokens from reasoning
+        unmasked_count = sum(1 for label in result["labels"] if label != -100)
+        assert unmasked_count > 0, "Should have unmasked reasoning content"
+
+
+class TestErrorConditions:
+    """Test various error conditions and edge cases."""
+
+    @pytest.fixture
+    def mock_tokenizer(self):
+        """Mock tokenizer for error testing."""
+        tokenizer = Mock()
+        tokenizer.encode.side_effect = lambda text, add_special_tokens=False: {
+            UNMASK_BEGIN_TOKEN: [1000],
+            UNMASK_END_TOKEN: [1001],
+        }.get(text, [100])
+        tokenizer.eos_token = None
+        return tokenizer
+
+    def test_length_mismatch_error(self, mock_tokenizer):
+        """Test the error message contract for internal length mismatches."""
+        # This would be an internal error where our processing logic fails
+        messages = [{"role": "assistant", "content": "Hello"}]
+        mock_tokenizer.apply_chat_template.return_value = [1000, 200, 1001]
+
+        # Patch unmask_messages so the mock raises the documented error; this
+        # verifies the expected error message, not the real code path
+        with patch("instructlab.training.data_process.unmask_messages") as mock_unmask:
+            mock_unmask.side_effect = RuntimeError(
+                "final_input_ids and final_labels are not the same length"
+            )
+
+            with pytest.raises(
+                RuntimeError,
+                match="final_input_ids and final_labels are not the same length",
+            ):
+                mock_unmask(messages, mock_tokenizer, ["assistant"])
+
+    def test_unfinished_unmasking_error(self, mock_tokenizer):
+        """Test error when unmasking is not properly finished."""
+        messages = [{"role": "assistant", "content": "Hello"}]
+
+        # Begin token without end token
+        mock_tokenizer.apply_chat_template.return_value = [1000, 200]
+
+        with pytest.raises(
+            RuntimeError, match="unmasking finished but not all messages were processed"
+        ):
+            unmask_messages(messages, mock_tokenizer, ["assistant"])
+
+
+class TestRealTokenizersUnmaskBehavior:
+    """Test suite for validating unmask behavior with real tokenizers."""
+
+    @pytest.fixture(
+        scope="class", params=["Qwen/Qwen3-32B", "ibm-granite/granite-3.1-8b-instruct"]
+    )
+    def real_tokenizer(self, request):
+        """Load real tokenizers for comprehensive testing."""
+        try:
+            # Make sure downloads are allowed so the tokenizers can be fetched
+            os.environ["TRANSFORMERS_OFFLINE"] = "0"
+
+            tokenizer = AutoTokenizer.from_pretrained(request.param, cache_dir=".cache")
+
+            # Add the special unmask tokens
+            tokenizer.add_special_tokens(
+                {
+                    "additional_special_tokens": [
+                        UNMASK_BEGIN_TOKEN,
+                        UNMASK_END_TOKEN,
+                        UNMASK_REASONING_BEGIN_TOKEN,
+                        UNMASK_REASONING_END_TOKEN,
+                    ]
+                }
+            )
+
+            # Store the model name for test identification
+            tokenizer._model_name = request.param
+
+            return tokenizer
+        except Exception as e:
+            pytest.skip(f"Could not load tokenizer {request.param}: {e}")
+
+    @pytest.fixture
+    def 
sample_unmask_true(self): + """Sample with unmask: True - should unmask user and assistant, but not system.""" + return { + "messages": [ + { + "role": "system", + "content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", + }, + { + "role": "user", + "content": 'For the word "dream", give an example of a word that rhymes with it and its synonym.', + }, + { + "role": "assistant", + "content": 'Here\'s an example for "dream" that includes a word that rhymes with it and a synonym:\n1. Word that rhymes with "dream": "beam"\nSynonym: "ideal"', + }, + ], + "unmask": True, + } + + @pytest.fixture + def sample_unmask_false(self): + """Sample with unmask: False - should only unmask assistant.""" + return { + "messages": [ + { + "role": "system", + "content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", + }, + { + "role": "user", + "content": 'Using the word "grace", come up with a word that rhymes and has the same number of syllables', + }, + { + "role": "assistant", + "content": 'Certainly! Here\'s a word that rhymes with "grace" and has the same number of syllables:\n1. Space', + }, + ], + "unmask": False, + } + + @pytest.fixture + def sample_with_reasoning(self): + """Sample with reasoning content.""" + return { + "messages": [ + {"role": "user", "content": "What is 2+2?"}, + { + "role": "assistant", + "content": "The answer is 4.", + "reasoning_content": "I need to add 2 and 2 together. 2 + 2 = 4.", + }, + ], + "unmask": False, + } + + @pytest.mark.slow + def test_unmask_sample_with_unmask_true(self, real_tokenizer, sample_unmask_true): + """Test that unmask: True correctly unmasks user and assistant but not system.""" + result = unmask_sample(sample_unmask_true, real_tokenizer) + + # Basic validation + assert "input_ids" in result + assert "labels" in result + assert "len" in result + assert len(result["input_ids"]) == len(result["labels"]) + assert result["len"] == len(result["input_ids"]) + + # Decode the sequences to validate content + input_text = real_tokenizer.decode( + result["input_ids"], skip_special_tokens=False + ) + + # Should contain parts of user and assistant content but not system in labels + masked_positions = [ + i for i, label in enumerate(result["labels"]) if label == -100 + ] + unmasked_positions = [ + i for i, label in enumerate(result["labels"]) if label != -100 + ] + + # Must have both masked and unmasked tokens + assert len(masked_positions) > 0, ( + f"Expected some masked tokens for {real_tokenizer._model_name}" + ) + assert len(unmasked_positions) > 0, ( + f"Expected some unmasked tokens for {real_tokenizer._model_name}" + ) + + # Verify that unmasked tokens match input_ids + for pos in unmasked_positions: + assert result["labels"][pos] == result["input_ids"][pos], ( + f"Unmasked position {pos} should have matching label and input_id" + ) + + # Check that unmask tokens are not present in final output + assert UNMASK_BEGIN_TOKEN.encode() not in input_text.encode() + assert UNMASK_END_TOKEN.encode() not in input_text.encode() + + print(f"\n=== {real_tokenizer._model_name} - UNMASK: TRUE ===") + print(f"Input text: {input_text}") + print(f"Total tokens: {len(result['input_ids'])}") + print(f"Masked tokens: {len(masked_positions)}") 
+ print(f"Unmasked tokens: {len(unmasked_positions)}") + + # Create a visual representation of masking + visual_labels = [] + for i, (token_id, label) in enumerate( + zip(result["input_ids"], result["labels"]) + ): + token_text = real_tokenizer.decode([token_id]) + if label == -100: + visual_labels.append("<|MASK|>") + else: + visual_labels.append(token_text) + + print(f"Visual masking: {''.join(visual_labels)}") + + @pytest.mark.slow + def test_unmask_sample_with_unmask_false(self, real_tokenizer, sample_unmask_false): + """Test that unmask: False correctly unmasks only assistant.""" + result = unmask_sample(sample_unmask_false, real_tokenizer) + + # Basic validation + assert "input_ids" in result + assert "labels" in result + assert "len" in result + assert len(result["input_ids"]) == len(result["labels"]) + assert result["len"] == len(result["input_ids"]) + + # Decode the sequences to validate content + input_text = real_tokenizer.decode( + result["input_ids"], skip_special_tokens=False + ) + + # Should have more masked tokens than unmask=True case since only assistant is unmasked + masked_positions = [ + i for i, label in enumerate(result["labels"]) if label == -100 + ] + unmasked_positions = [ + i for i, label in enumerate(result["labels"]) if label != -100 + ] + + # Must have both masked and unmasked tokens + assert len(masked_positions) > 0, ( + f"Expected some masked tokens for {real_tokenizer._model_name}" + ) + assert len(unmasked_positions) > 0, ( + f"Expected some unmasked tokens for {real_tokenizer._model_name}" + ) + + # Verify that unmasked tokens match input_ids + for pos in unmasked_positions: + assert result["labels"][pos] == result["input_ids"][pos], ( + f"Unmasked position {pos} should have matching label and input_id" + ) + + # Check that unmask tokens are not present in final output + assert UNMASK_BEGIN_TOKEN.encode() not in input_text.encode() + assert UNMASK_END_TOKEN.encode() not in input_text.encode() + + print(f"\n=== {real_tokenizer._model_name} - UNMASK: FALSE ===") + print(f"Input text: {input_text}") + print(f"Total tokens: {len(result['input_ids'])}") + print(f"Masked tokens: {len(masked_positions)}") + print(f"Unmasked tokens: {len(unmasked_positions)}") + + # Create a visual representation of masking + visual_labels = [] + for i, (token_id, label) in enumerate( + zip(result["input_ids"], result["labels"]) + ): + token_text = real_tokenizer.decode([token_id]) + if label == -100: + visual_labels.append("<|MASK|>") + else: + visual_labels.append(token_text) + + print(f"Visual masking: {''.join(visual_labels)}") + + @pytest.mark.slow + def test_unmask_comparison_between_settings( + self, real_tokenizer, sample_unmask_true, sample_unmask_false + ): + """Test that unmask: True results in fewer masked tokens than unmask: False.""" + result_true = unmask_sample(sample_unmask_true, real_tokenizer) + result_false = unmask_sample(sample_unmask_false, real_tokenizer) + + masked_count_true = sum(1 for label in result_true["labels"] if label == -100) + masked_count_false = sum(1 for label in result_false["labels"] if label == -100) + + unmasked_count_true = len(result_true["labels"]) - masked_count_true + unmasked_count_false = len(result_false["labels"]) - masked_count_false + + # unmask: True should have more unmasked tokens (user + assistant vs just assistant) + assert unmasked_count_true > unmasked_count_false, ( + f"unmask: True should unmask more tokens than unmask: False for {real_tokenizer._model_name}" + ) + + @pytest.mark.slow + def 
test_unmask_with_reasoning_content(self, real_tokenizer, sample_with_reasoning): + """Test that reasoning content is properly handled.""" + result = unmask_sample(sample_with_reasoning, real_tokenizer) + + # Basic validation + assert "input_ids" in result + assert "labels" in result + assert "len" in result + assert len(result["input_ids"]) == len(result["labels"]) + + # Should have processed both content and reasoning_content + unmasked_positions = [ + i for i, label in enumerate(result["labels"]) if label != -100 + ] + assert len(unmasked_positions) > 0, ( + "Should have unmasked tokens from assistant content and reasoning" + ) + + # Decode to see the result + input_text = real_tokenizer.decode( + result["input_ids"], skip_special_tokens=False + ) + print(f"\n=== {real_tokenizer._model_name} - WITH REASONING ===") + print(f"Input text: {input_text}") + print(f"Total tokens: {len(result['input_ids'])}") + print(f"Unmasked tokens: {len(unmasked_positions)}") + + def test_token_id_consistency(self, real_tokenizer, sample_unmask_true): + """Test that token IDs are consistent and valid.""" + result = unmask_sample(sample_unmask_true, real_tokenizer) + + # All input_ids should be valid token IDs + for token_id in result["input_ids"]: + assert isinstance(token_id, int), "All token IDs should be integers" + assert 0 <= token_id < len(real_tokenizer), ( + "Token IDs should be within vocabulary range" + ) + + # All non-masked labels should match their corresponding input_ids + for i, (input_id, label) in enumerate( + zip(result["input_ids"], result["labels"]) + ): + if label != -100: + assert label == input_id, ( + f"Position {i}: label {label} should match input_id {input_id}" + ) + + # Verify we can decode all tokens + decoded_text = real_tokenizer.decode(result["input_ids"]) + assert isinstance(decoded_text, str), ( + "Should be able to decode all tokens to string" + ) + assert len(decoded_text) > 0, "Decoded text should not be empty" + + def test_special_tokens_removed_from_output( + self, real_tokenizer, sample_unmask_true + ): + """Test that special unmask tokens are properly removed from final output.""" + result = unmask_sample(sample_unmask_true, real_tokenizer) + + # Get token IDs for special tokens + unmask_begin_id = real_tokenizer.encode( + UNMASK_BEGIN_TOKEN, add_special_tokens=False + )[0] + unmask_end_id = real_tokenizer.encode( + UNMASK_END_TOKEN, add_special_tokens=False + )[0] + unmask_reasoning_begin_id = real_tokenizer.encode( + UNMASK_REASONING_BEGIN_TOKEN, add_special_tokens=False + )[0] + unmask_reasoning_end_id = real_tokenizer.encode( + UNMASK_REASONING_END_TOKEN, add_special_tokens=False + )[0] + + # None of these should appear in final output + assert unmask_begin_id not in result["input_ids"], ( + "UNMASK_BEGIN_TOKEN should not be in final input_ids" + ) + assert unmask_end_id not in result["input_ids"], ( + "UNMASK_END_TOKEN should not be in final input_ids" + ) + assert unmask_reasoning_begin_id not in result["input_ids"], ( + "UNMASK_REASONING_BEGIN_TOKEN should not be in final input_ids" + ) + assert unmask_reasoning_end_id not in result["input_ids"], ( + "UNMASK_REASONING_END_TOKEN should not be in final input_ids" + ) + + # Same for labels + assert unmask_begin_id not in result["labels"], ( + "UNMASK_BEGIN_TOKEN should not be in final labels" + ) + assert unmask_end_id not in result["labels"], ( + "UNMASK_END_TOKEN should not be in final labels" + ) + assert unmask_reasoning_begin_id not in result["labels"], ( + "UNMASK_REASONING_BEGIN_TOKEN should not be in 
final labels" + ) + assert unmask_reasoning_end_id not in result["labels"], ( + "UNMASK_REASONING_END_TOKEN should not be in final labels" + ) + + def test_reproducibility(self, real_tokenizer, sample_unmask_true): + """Test that the same input produces the same output consistently.""" + result1 = unmask_sample(sample_unmask_true, real_tokenizer) + result2 = unmask_sample(sample_unmask_true, real_tokenizer) + + assert result1["input_ids"] == result2["input_ids"], ( + "Results should be reproducible" + ) + assert result1["labels"] == result2["labels"], "Results should be reproducible" + assert result1["len"] == result2["len"], "Results should be reproducible" + + +class TestUnmaskSampleLogic: + """Test the logic of unmask_sample without requiring full tokenizer loading.""" + + @pytest.fixture + def mock_tokenizer_for_unmask_sample(self): + """Create a comprehensive mock tokenizer for testing unmask_sample logic.""" + tokenizer = Mock() + + # Mock the special token encodings - using unique IDs for each token + def mock_encode(text, add_special_tokens=False): + token_map = { + UNMASK_BEGIN_TOKEN: [1000], + UNMASK_END_TOKEN: [1001], + UNMASK_REASONING_BEGIN_TOKEN: [1002], + UNMASK_REASONING_END_TOKEN: [1003], + "<|endoftext|>": [0], + } + # Return predictable token IDs based on hash for consistent testing + return token_map.get(text, [abs(hash(text)) % 500 + 100]) + + tokenizer.encode = mock_encode + tokenizer.eos_token = "<|endoftext|>" + + # Mock apply_chat_template to return a sequence that represents: + # system_tokens + unmask_begin + user_tokens + unmask_end + unmask_begin + assistant_tokens + unmask_end + eos + def mock_apply_chat_template(messages, **kwargs): + sequence = [] + for msg in messages: + role = msg["role"] + content = msg.get("content", "") + reasoning_content = msg.get("reasoning_content", "") + + # Add role-specific tokens + if role == "system": + sequence.extend([10, 11, 12]) # system role tokens + elif role == "user": + sequence.extend([20, 21]) # user role tokens + elif role == "assistant": + sequence.extend([30, 31]) # assistant role tokens + + # Add reasoning content if present and wrapped + if ( + reasoning_content + and UNMASK_REASONING_BEGIN_TOKEN in reasoning_content + ): + sequence.append(1002) # reasoning begin + sequence.extend([250, 251, 252]) # reasoning content tokens + sequence.append(1003) # reasoning end + + # Add content tokens + if content and UNMASK_BEGIN_TOKEN in content: + sequence.append(1000) # unmask begin + # Add content tokens based on role + if role == "user": + sequence.extend([200, 201, 202]) + elif role == "assistant": + sequence.extend([300, 301, 302]) + sequence.append(1001) # unmask end + elif content: + # Non-wrapped content + if role == "system": + sequence.extend([100, 101, 102, 103]) + elif role == "user": + sequence.extend([200, 201, 202]) + elif role == "assistant": + sequence.extend([300, 301, 302]) + + sequence.append(0) # eos token + return sequence + + tokenizer.apply_chat_template = mock_apply_chat_template + return tokenizer + + def test_unmask_sample_unmask_false_logic(self, mock_tokenizer_for_unmask_sample): + """Test unmask: False logic - should only unmask assistant role.""" + sample = { + "messages": [ + {"role": "system", "content": "System message"}, + {"role": "user", "content": "User message"}, + {"role": "assistant", "content": "Assistant message"}, + ], + "unmask": False, + } + + result = unmask_sample(sample, mock_tokenizer_for_unmask_sample) + + # Basic validation + assert len(result["input_ids"]) == 
len(result["labels"]) + + # Count masked vs unmasked tokens + masked_count = sum(1 for label in result["labels"] if label == -100) + unmasked_count = len(result["labels"]) - masked_count + + assert masked_count > 0, "Should have some masked tokens" + assert unmasked_count > 0, "Should have some unmasked tokens (assistant)" + + # Verify that assistant tokens are unmasked + # The mock returns assistant content as tokens [300, 301, 302] + assistant_tokens = [300, 301, 302] + for token in assistant_tokens: + if token in result["input_ids"]: + idx = result["input_ids"].index(token) + assert result["labels"][idx] == token, ( + f"Assistant token {token} should be unmasked" + ) + + def test_unmask_sample_unmask_true_logic(self, mock_tokenizer_for_unmask_sample): + """Test unmask: True logic - should unmask user and assistant, but not system.""" + sample = { + "messages": [ + {"role": "system", "content": "System message"}, + {"role": "user", "content": "User message"}, + {"role": "assistant", "content": "Assistant message"}, + ], + "unmask": True, + } + + result = unmask_sample(sample, mock_tokenizer_for_unmask_sample) + + # Basic validation + assert len(result["input_ids"]) == len(result["labels"]) + + # Count masked vs unmasked tokens + masked_count = sum(1 for label in result["labels"] if label == -100) + unmasked_count = len(result["labels"]) - masked_count + + assert masked_count > 0, "Should have some masked tokens (system)" + assert unmasked_count > 0, "Should have some unmasked tokens (user + assistant)" + + # Verify that both user and assistant tokens are unmasked + user_tokens = [200, 201, 202] + assistant_tokens = [300, 301, 302] + + for token in user_tokens + assistant_tokens: + if token in result["input_ids"]: + idx = result["input_ids"].index(token) + assert result["labels"][idx] == token, ( + f"User/Assistant token {token} should be unmasked" + ) + + def test_unmask_sample_comparison(self, mock_tokenizer_for_unmask_sample): + """Test that unmask: True unmasks more tokens than unmask: False.""" + sample_base = { + "messages": [ + {"role": "system", "content": "System message"}, + {"role": "user", "content": "User message"}, + {"role": "assistant", "content": "Assistant message"}, + ] + } + + sample_false = {**sample_base, "unmask": False} + sample_true = {**sample_base, "unmask": True} + + result_false = unmask_sample(sample_false, mock_tokenizer_for_unmask_sample) + result_true = unmask_sample(sample_true, mock_tokenizer_for_unmask_sample) + + unmasked_false = sum(1 for label in result_false["labels"] if label != -100) + unmasked_true = sum(1 for label in result_true["labels"] if label != -100) + + assert unmasked_true > unmasked_false, ( + "unmask: True should unmask more tokens than unmask: False" + ) + + print(f"\nUnmask comparison:") + print(f"unmask: False -> {unmasked_false} unmasked tokens") + print(f"unmask: True -> {unmasked_true} unmasked tokens") + print( + f"Difference: +{unmasked_true - unmasked_false} more tokens unmasked with unmask: True" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])
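
For reviewers who want to exercise this pipeline outside of pytest, below is a minimal sketch of the call pattern the suites above rely on. It mirrors the fixtures in `TestRealTokenizersUnmaskBehavior`: the tokenizer name is illustrative, and the special-token constants and helpers are assumed to be the ones imported by the test files.

```python
# Minimal sketch (assumptions: the tokenizer choice is illustrative; the
# constants and unmask_sample are the same objects the tests above import).
from transformers import AutoTokenizer

from instructlab.training.data_process import (
    UNMASK_BEGIN_TOKEN,
    UNMASK_END_TOKEN,
    UNMASK_REASONING_BEGIN_TOKEN,
    UNMASK_REASONING_END_TOKEN,
    unmask_sample,
)

# Register the sentinel tokens the unmasking pass looks for.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-32B")
tokenizer.add_special_tokens(
    {
        "additional_special_tokens": [
            UNMASK_BEGIN_TOKEN,
            UNMASK_END_TOKEN,
            UNMASK_REASONING_BEGIN_TOKEN,
            UNMASK_REASONING_END_TOKEN,
        ]
    }
)

sample = {
    "messages": [
        {"role": "user", "content": "What is 2+2?"},
        {
            "role": "assistant",
            "content": "The answer is 4.",
            "reasoning_content": "I need to add 2 and 2. 2 + 2 = 4.",
        },
    ],
    "unmask": False,  # False: only the assistant turn is unmasked
}

result = unmask_sample(sample, tokenizer)

# The sentinel tokens never survive into the final tensors, and every
# label that is not -100 mirrors the token at the same position.
assert len(result["input_ids"]) == len(result["labels"])
assert any(label != -100 for label in result["labels"])
```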