From 2667a08fa738f97bd1b96b7226ab3bdc271d32e4 Mon Sep 17 00:00:00 2001
From: 0xrushi <6279035+0xrushi@users.noreply.github.com>
Date: Sat, 11 Oct 2025 05:40:44 -0400
Subject: [PATCH] feat: implement constrained selection parser and logprobs support

---
 .../components/model_client/openai_client.py |  74 +++++-
 .../components/output_parsers/__init__.py    |   3 +
 .../output_parsers/constrained_parser.py     | 235 ++++++++++++++++++
 adalflow/adalflow/core/generator.py          |  98 ++++++++
 adalflow/adalflow/core/model_client.py       |  33 ++-
 adalflow/adalflow/core/types.py              |   6 +-
 adalflow/tests/test_constrained_parser.py    |  38 +++
 adalflow/tests/test_logprob_support.py       | 139 +++++++++++
 8 files changed, 623 insertions(+), 3 deletions(-)
 create mode 100644 adalflow/adalflow/components/output_parsers/constrained_parser.py
 create mode 100644 adalflow/tests/test_constrained_parser.py
 create mode 100644 adalflow/tests/test_logprob_support.py

diff --git a/adalflow/adalflow/components/model_client/openai_client.py b/adalflow/adalflow/components/model_client/openai_client.py
index 6d7c4fe49..e3c6c6d55 100644
--- a/adalflow/adalflow/components/model_client/openai_client.py
+++ b/adalflow/adalflow/components/model_client/openai_client.py
@@ -24,7 +24,7 @@
 
 # optional import
 from adalflow.utils.lazy_import import safe_import, OptionalPackages
-
+from adalflow.core.types import TokenLogProb
 
 openai = safe_import(OptionalPackages.OPENAI.value[0], OptionalPackages.OPENAI.value[1])
 
@@ -1117,6 +1117,78 @@ async def acall(
         else:
             raise ValueError(f"model_type {model_type} is not supported")
 
+    def _extract_logprobs(self, completion: Any) -> List[List["TokenLogProb"]]:
+        """Extract logprobs from an OpenAI completion response.
+
+        Args:
+            completion: OpenAI completion response
+
+        Returns:
+            List of token logprobs for each choice
+        """
+        logprobs = []
+
+        try:
+            if hasattr(completion, "choices"):
+                for choice in completion.choices:
+                    if hasattr(choice, "logprobs") and choice.logprobs:
+                        choice_logprobs = []
+                        for token_logprob in choice.logprobs.content:
+                            choice_logprobs.append(
+                                TokenLogProb(
+                                    token=token_logprob.token,
+                                    logprob=token_logprob.logprob,
+                                    choice_index=getattr(choice, "index", None),
+                                )
+                            )
+                        logprobs.append(choice_logprobs)
+                    else:
+                        logprobs.append([])
+        except Exception as e:
+            log.error(f"Failed to extract logprobs: {e}")
+            raise
+
+        return logprobs
+
+    def call_with_logprobs(
+        self,
+        input: str = "",
+        model_kwargs: Dict = {},
+        model_type: ModelType = ModelType.UNDEFINED,
+    ) -> tuple[Any, List[List[TokenLogProb]]]:
+        """Call the API with logprobs enabled for constrained generation.
+
+        This method uses the traditional chat.completions API, which supports
+        logprobs, instead of the newer responses API, which does not.
+
+        Args:
+            input: The input text to process
+            model_kwargs: Model parameters
+            model_type: Type of model call
+
+        Returns:
+            Tuple of (completion, logprobs) where logprobs is a list of token
+            logprobs for each choice.
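+
+        Example:
+            A minimal sketch, assuming an initialized ``OpenAIClient`` and a
+            chat model that returns token logprobs::
+
+                completion, logprobs = client.call_with_logprobs(
+                    input="Is this review positive or negative?",
+                    model_kwargs={"model": "gpt-3.5-turbo"},
+                )
+                first_choice_tokens = logprobs[0]  # List[TokenLogProb]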
+        """
+
+        log.debug("call with logprobs using chat.completions")
+
+        messages = [{"role": "user", "content": str(input)}]
+
+        chat_kwargs = {
+            "model": model_kwargs.get("model", "gpt-3.5-turbo"),
+            "messages": messages,
+            "logprobs": True,
+            "top_logprobs": model_kwargs.get("top_logprobs", 5),
+            "temperature": model_kwargs.get("temperature", 0.1),
+            "max_tokens": model_kwargs.get("max_tokens", 1000),
+        }
+
+        completion = self.sync_client.chat.completions.create(**chat_kwargs)
+
+        logprobs = self._extract_logprobs(completion)
+
+        return completion, logprobs
+
     @classmethod
     def from_dict(cls: type[T], data: Dict[str, Any]) -> T:
         obj = super().from_dict(data)
diff --git a/adalflow/adalflow/components/output_parsers/__init__.py b/adalflow/adalflow/components/output_parsers/__init__.py
index 3eeb77453..ed44df326 100644
--- a/adalflow/adalflow/components/output_parsers/__init__.py
+++ b/adalflow/adalflow/components/output_parsers/__init__.py
@@ -7,6 +7,7 @@
     LIST_OUTPUT_FORMAT,
 )
 from .dataclass_parser import DataClassParser
+from .constrained_parser import ConstrainedSelectionParser, MultiChoiceParser
 
 __all__ = [
     "YamlOutputParser",
@@ -16,4 +17,6 @@
     "JSON_OUTPUT_FORMAT",
     "LIST_OUTPUT_FORMAT",
     "DataClassParser",
+    "ConstrainedSelectionParser",
+    "MultiChoiceParser",
 ]
diff --git a/adalflow/adalflow/components/output_parsers/constrained_parser.py b/adalflow/adalflow/components/output_parsers/constrained_parser.py
new file mode 100644
index 000000000..bd3de3cd0
--- /dev/null
+++ b/adalflow/adalflow/components/output_parsers/constrained_parser.py
@@ -0,0 +1,235 @@
+"""Constrained Selection Parser for forcing models to select from specific options.
+
+This parser uses logprobs to force models to select from a predefined set of options,
+similar to the Guidance library's constrained generation capabilities.
+"""
+
+import logging
+from typing import List, Dict, Any, Optional, Union
+
+from adalflow.core.component import DataComponent
+from adalflow.core.types import TokenLogProb
+
+log = logging.getLogger(__name__)
+
+
+class ConstrainedSelectionParser(DataComponent):
+    """Parser that forces the model to select from a constrained set of options using logprobs.
+
+    This parser is particularly useful for classification tasks where you want to ensure
+    the model only outputs one of the predefined options.
+
+    Example:
+        parser = ConstrainedSelectionParser(
+            options=["positive", "negative", "neutral"],
+            allow_partial_match=True
+        )
+
+        # Use with Generator
+        generator = Generator(
+            model_client=model_client,
+            output_processors=parser,
+            # ... other args
+        )
+    """
+
+    def __init__(
+        self,
+        options: List[str],
+        allow_partial_match: bool = True,
+        case_sensitive: bool = False,
+        max_tokens: int = 10,
+        temperature: float = 0.0,
+    ):
+        """Initialize the constrained selection parser.
+
+        Args:
+            options: List of valid options the model can select from
+            allow_partial_match: Whether to allow partial matches (e.g., "pos" matches "positive")
+            case_sensitive: Whether option matching is case sensitive
+            max_tokens: Maximum number of tokens to consider for selection
+            temperature: Temperature for logprob-based selection (0.0 = deterministic)
+        """
+        super().__init__()
+
+        if not options:
+            raise ValueError("Options list cannot be empty")
+
+        self.options = options
+        self.allow_partial_match = allow_partial_match
+        self.case_sensitive = case_sensitive
+        self.max_tokens = max_tokens
+        self.temperature = temperature
+
+        # Normalize options for matching
+        if not case_sensitive:
+            self.normalized_options = {opt.lower(): opt for opt in options}
+        else:
+            self.normalized_options = {opt: opt for opt in options}
+
+    def _normalize_text(self, text: str) -> str:
+        """Normalize text for matching."""
+        return text if self.case_sensitive else text.lower()
+
+    def _find_best_match(self, text: str) -> Optional[str]:
+        """Find the best matching option for the given text."""
+        normalized_text = self._normalize_text(text.strip())
+
+        # Exact match first
+        if normalized_text in self.normalized_options:
+            return self.normalized_options[normalized_text]
+
+        if not self.allow_partial_match:
+            return None
+
+        # Partial match: substring containment in either direction
+        for normalized_option, original_option in self.normalized_options.items():
+            if normalized_option in normalized_text or normalized_text in normalized_option:
+                return original_option
+
+        return None
+
+    def _select_from_logprobs(self, logprobs: List[List[TokenLogProb]]) -> Optional[str]:
+        """Select the best option using logprobs."""
+        if not logprobs:
+            return None
+
+        # Flatten all logprobs
+        all_tokens = []
+        for token_list in logprobs:
+            all_tokens.extend(token_list)
+
+        if not all_tokens:
+            return None
+
+        # Calculate scores for each option
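+        # Worked sketch: with tokens ("positive", -0.1) and ("negative", -5.0),
+        # the option "positive" accumulates -0.1 over one matched token and
+        # "negative" averages -5.0, so the less negative average logprob wins.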
+        option_scores = {}
+
+        for option in self.options:
+            normalized_option = self._normalize_text(option)
+            option_words = normalized_option.split()
+
+            # Score based on token probabilities
+            score = 0.0
+            matched_tokens = 0
+
+            for token in all_tokens[:self.max_tokens]:
+                token_text = self._normalize_text(token.token)
+
+                for word in option_words:
+                    if word in token_text or token_text in word:
+                        # Use the logprob as the score (higher is better)
+                        score += token.logprob
+                        matched_tokens += 1
+                        break
+
+            if matched_tokens > 0:
+                # Normalize the score by the number of matched tokens
+                option_scores[option] = score / matched_tokens
+
+        if not option_scores:
+            return None
+
+        best_option = max(option_scores.items(), key=lambda x: x[1])[0]
+        return best_option
+
+    def call(self, input_data: Union[str, Dict[str, Any]]) -> str:
+        """Parse input and return the best matching option.
+
+        Args:
+            input_data: Either a string response or a dict with 'response' and 'logprobs' keys
+
+        Returns:
+            The best matching option from the predefined list
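+
+        Example:
+            An illustrative sketch of the text-matching path::
+
+                parser = ConstrainedSelectionParser(options=["yes", "no"])
+                parser.call("Yes, definitely")  # -> "yes" (partial match)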
+        """
+        if isinstance(input_data, str):
+            # Simple text-based matching
+            result = self._find_best_match(input_data)
+            if result is None:
+                log.warning(f"No matching option found for: {input_data}")
+                return self.options[0]  # Return first option as fallback
+            return result
+
+        elif isinstance(input_data, dict):
+            # Use logprobs if available
+            if 'logprobs' in input_data and input_data['logprobs']:
+                result = self._select_from_logprobs(input_data['logprobs'])
+                if result is not None:
+                    return result
+
+            # Fallback to text matching
+            response_text = input_data.get('response', '')
+            if response_text:
+                result = self._find_best_match(response_text)
+                if result is not None:
+                    return result
+
+            log.warning(f"No matching option found in: {input_data}")
+            return self.options[0]
+
+        else:
+            log.warning(f"Unsupported input type: {type(input_data)}")
+            return self.options[0]
+
+    def get_format_instructions(self) -> str:
+        """Get format instructions for the prompt."""
+        options_str = ", ".join([f'"{opt}"' for opt in self.options])
+
+        return f"""You must respond with exactly one of these options: {options_str}
+
+Rules:
+- Choose the option that best matches your analysis
+- Respond with only the option text, no additional explanation
+- If unsure, choose the most appropriate option from the list
+- Your response must be one of: {options_str}"""
+
+    def _extra_repr(self) -> str:
+        return f"options={self.options}, allow_partial_match={self.allow_partial_match}, case_sensitive={self.case_sensitive}"
+
+
+class MultiChoiceParser(ConstrainedSelectionParser):
+    """Parser for multiple choice questions with options A, B, C, D, etc.
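+
+    Example:
+        A minimal sketch; single-letter options are matched exactly, so it
+        works best when the model replies with just the letter::
+
+            parser = MultiChoiceParser(num_choices=4, choice_format="letter")
+            parser.call("B")  # -> "B"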
+    """
+
+    def __init__(
+        self,
+        num_choices: int = 4,
+        choice_format: str = "letter",  # "letter" or "number"
+        **kwargs,
+    ):
+        """Initialize multi-choice parser.
+
+        Args:
+            num_choices: Number of choices (default 4 for A, B, C, D)
+            choice_format: Format of choices - "letter" (A, B, C, D) or "number" (1, 2, 3, 4)
+            **kwargs: Extra arguments forwarded to ConstrainedSelectionParser
+        """
+        if choice_format == "letter":
+            options = [chr(65 + i) for i in range(num_choices)]  # A, B, C, D, ...
+        elif choice_format == "number":
+            options = [str(i + 1) for i in range(num_choices)]  # 1, 2, 3, 4, ...
+        else:
+            raise ValueError("choice_format must be 'letter' or 'number'")
+
+        super().__init__(options=options, **kwargs)
+        self.num_choices = num_choices
+        self.choice_format = choice_format
+
+    def get_format_instructions(self) -> str:
+        """Get format instructions for multiple choice."""
+        if self.choice_format == "letter":
+            choices = [chr(65 + i) for i in range(self.num_choices)]
+            return f"""You must respond with exactly one letter: {', '.join(choices)}
+
+Rules:
+- Choose the letter that corresponds to the best answer
+- Respond with only the letter (A, B, C, D, etc.)
+- No additional text or explanation
+- Your response must be one of: {', '.join(choices)}"""
+        else:
+            choices = [str(i + 1) for i in range(self.num_choices)]
+            return f"""You must respond with exactly one number: {', '.join(choices)}
+
+Rules:
+- Choose the number that corresponds to the best answer
+- Respond with only the number (1, 2, 3, 4, etc.)
+- No additional text or explanation
+- Your response must be one of: {', '.join(choices)}"""
diff --git a/adalflow/adalflow/core/generator.py b/adalflow/adalflow/core/generator.py
index 0384b3a72..8f6fafad0 100644
--- a/adalflow/adalflow/core/generator.py
+++ b/adalflow/adalflow/core/generator.py
@@ -1431,6 +1431,104 @@ def failure_message_to_backward_engine(
         response_value = f"Error: {gradient_response.error}, Raw response: {gradient_response.raw_response}"
         return response_value
 
+    def select_from_options(
+        self,
+        options: List[str],
+        prompt_kwargs: Optional[Dict[str, Union[str, Parameter]]] = {},
+        model_kwargs: Optional[Dict] = {},
+        allow_partial_match: bool = True,
+        case_sensitive: bool = False,
+        id: Optional[str] = None,
+    ) -> str:
+        """Select from a constrained set of options using logprobs.
+
+        This method forces the model to select from a predefined set of options,
+        similar to the Guidance library's constrained generation capabilities.
+
+        Args:
+            options: List of valid options the model can select from
+            prompt_kwargs: Prompt arguments to fill in the template
+            model_kwargs: Model arguments for the API call
+            allow_partial_match: Whether to allow partial matches
+            case_sensitive: Whether option matching is case sensitive
+            id: Optional ID for tracing
+
+        Returns:
+            The selected option from the predefined list
+
+        Example:
+            generator = Generator(
+                model_client=OpenAIClient(),
+                model_kwargs={"model": "gpt-3.5-turbo"}
+            )
+
+            # Force selection from specific options
+            result = generator.select_from_options(
+                options=["positive", "negative", "neutral"],
+                prompt_kwargs={"input_str": "How do you feel about this?"}
+            )
+            # Returns one of: "positive", "negative", "neutral"
+        """
+        from adalflow.components.output_parsers import ConstrainedSelectionParser
+
+        parser = ConstrainedSelectionParser(
+            options=options,
+            allow_partial_match=allow_partial_match,
+            case_sensitive=case_sensitive,
+        )
+
+        selection_instructions = parser.get_format_instructions()
+
+        original_template = self.template
+        modified_template = f"{original_template}\n\n{selection_instructions}"
+
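+        # Temporarily extend the template with the selection instructions; the
+        # original template is restored in the ``finally`` block below.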
+        self.template = modified_template
+
+        try:
+            # Use logprob-based selection when the client implements it
+            try:
+                input_text = self.get_prompt(**(prompt_kwargs or {}))
+                composed_model_kwargs = self._compose_model_kwargs(
+                    **(model_kwargs or {})
+                )
+                completion, logprobs = self.model_client.call_with_logprobs(
+                    input=input_text,
+                    model_kwargs=composed_model_kwargs,
+                    model_type=self.model_type,
+                )
+
+                try:
+                    response_text = completion.choices[0].message.content
+                except Exception:
+                    response_text = str(completion)
+
+                return parser.call(
+                    {"response": response_text, "logprobs": logprobs}
+                )
+            except NotImplementedError:
+                log.debug(
+                    "Model client does not implement logprobs; falling back to text matching."
+                )
+            except AttributeError:
+                log.debug(
+                    "Model client has no logprob support; falling back to text matching."
+                )
+            except Exception as e:
+                log.error(f"Error calling model client with logprobs: {e}")
+                raise
+
+            # Fall back to a regular call with text-based matching
+            result = self.call(
+                prompt_kwargs=prompt_kwargs,
+                model_kwargs=model_kwargs,
+                id=id,
+            )
+            if hasattr(result, "data") and result.data is not None:
+                return parser.call(result.data)
+            return parser.call(str(result))
+        finally:
+            self.template = original_template
+
 
 class BackwardEngine(Generator):  # it is a generator with default template
diff --git a/adalflow/adalflow/core/model_client.py b/adalflow/adalflow/core/model_client.py
index a967ca670..a304e2ba2 100644
--- a/adalflow/adalflow/core/model_client.py
+++ b/adalflow/adalflow/core/model_client.py
@@ -1,6 +1,7 @@
 r"""ModelClient is the protocol and base class for all models (either via APIs or local models) to communicate with components."""
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, List
+from adalflow.core.types import TokenLogProb
 
 from adalflow.core.component import DataComponent
 
@@ -125,3 +126,33 @@ def list_models(self):
         raise NotImplementedError(
             f"{type(self).__name__} must implement list_models method"
         )
+
+    def call_with_logprobs(
+        self,
+        input: str = "",
+        model_kwargs: Dict = {},
+        model_type: ModelType = ModelType.UNDEFINED,
+    ) -> tuple[Any, List[List["TokenLogProb"]]]:
+        """Call the API with logprobs enabled for constrained generation.
+
+        Args:
+            input: The input text to process
+            model_kwargs: Model-specific keyword arguments. Implementations should
+                enable logprobs in the underlying request.
+            model_type: Type of model call
+
+        Returns:
+            Tuple of (completion, logprobs) where logprobs is a list of token logprob
+            objects for each choice. Each TokenLogProb includes the associated
+            ``choice_index`` when available.
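+
+        Example:
+            Expected return shape, as an illustrative sketch::
+
+                completion, logprobs = client.call_with_logprobs(
+                    input="prompt text",
+                    model_kwargs={"model": "model-name"},
+                )
+                token = logprobs[0][0]  # first token of the first choice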
+        """
+        raise NotImplementedError(
+            f"{type(self).__name__} (Optional) must implement call_with_logprobs method"
+        )
+
+    def _extract_logprobs(self, completion: Any) -> List[List["TokenLogProb"]]:
+        """Extract logprobs from a completion response.
+
+        This method should be overridden by subclasses to handle their
+        specific API format.
+        """
+        raise NotImplementedError(
+            f"{type(self).__name__} (Optional) must implement _extract_logprobs method"
+        )
diff --git a/adalflow/adalflow/core/types.py b/adalflow/adalflow/core/types.py
index 8d7b4f2b5..2dd37d4c3 100644
--- a/adalflow/adalflow/core/types.py
+++ b/adalflow/adalflow/core/types.py
@@ -219,10 +219,14 @@ def is_normalized(self) -> bool:
 ######################################################################################
 @dataclass
 class TokenLogProb:
-    r"""similar to openai.ChatCompletionTokenLogprob"""
+    r"""Represents a single token log probability returned by a model."""
 
     token: str
     logprob: float
+    choice_index: Optional[int] = field(
+        default=None,
+        metadata={"desc": "Index of the choice this token belongs to"},
+    )
 
 
 @dataclass
diff --git a/adalflow/tests/test_constrained_parser.py b/adalflow/tests/test_constrained_parser.py
new file mode 100644
index 000000000..a45b74d6b
--- /dev/null
+++ b/adalflow/tests/test_constrained_parser.py
@@ -0,0 +1,38 @@
+import unittest
+
+from adalflow.core.types import TokenLogProb
+from adalflow.components.output_parsers.constrained_parser import (
+    ConstrainedSelectionParser,
+)
+
+
+class ConstrainedSelectionParserTests(unittest.TestCase):
+    def test_logprob_selection_prefers_high_probability_option(self):
+        parser = ConstrainedSelectionParser(options=["A", "B"])
+        logprobs = [
+            [
+                TokenLogProb(token="A", logprob=-0.1, choice_index=0),
+                TokenLogProb(token="B", logprob=-5.0, choice_index=0),
+            ]
+        ]
+
+        result = parser.call({"response": "ignored", "logprobs": logprobs})
+        self.assertEqual("A", result)
+
+    def test_logprob_selection_falls_back_to_text_match(self):
+        parser = ConstrainedSelectionParser(
+            options=["Positive", "Negative"], allow_partial_match=False
+        )
+        result = parser.call({"response": "Negative", "logprobs": []})
+        self.assertEqual("Negative", result)
+
+    def test_no_match_returns_first_option(self):
+        parser = ConstrainedSelectionParser(
+            options=["Yes", "No"], allow_partial_match=False
+        )
+        result = parser.call("Maybe")
+        self.assertEqual("Yes", result)
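+
+    def test_partial_match_resolves_substring(self):
+        # "pos" is a substring of "positive", so partial matching (on by
+        # default) should resolve it to the full option.
+        parser = ConstrainedSelectionParser(options=["positive", "negative"])
+        self.assertEqual("positive", parser.call("pos"))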
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/adalflow/tests/test_logprob_support.py b/adalflow/tests/test_logprob_support.py
new file mode 100644
index 000000000..0d29f6e48
--- /dev/null
+++ b/adalflow/tests/test_logprob_support.py
@@ -0,0 +1,139 @@
+import unittest
+from types import SimpleNamespace
+from unittest.mock import patch
+
+from adalflow.core.generator import Generator
+from adalflow.core.model_client import ModelClient
+from adalflow.core.types import GeneratorOutput, TokenLogProb, ModelType
+from adalflow.components.model_client.openai_client import OpenAIClient
+
+
+class _DummyLogprobClient(ModelClient):
+    """Minimal ModelClient stub that reports logprob invocations."""
+
+    def __init__(self):
+        super().__init__()
+        self.last_input = None
+        self.last_model_kwargs = None
+
+    # The following abstract methods are unused in the tests but must be defined.
+    def convert_inputs_to_api_kwargs(self, input=None, model_kwargs=None, model_type=None):
+        return {"input": input, "model_kwargs": model_kwargs, "model_type": model_type}
+
+    def call(self, api_kwargs=None, model_type=ModelType.UNDEFINED):
+        raise AssertionError("call() should not be used in logprob path")
+
+    def parse_chat_completion(self, completion):
+        return GeneratorOutput(data=completion)
+
+    def track_completion_usage(self, *args, **kwargs):
+        return None
+
+    def list_models(self):
+        return []
+
+    def call_with_logprobs(self, input="", model_kwargs=None, model_type=ModelType.UNDEFINED):
+        self.last_input = input
+        self.last_model_kwargs = model_kwargs or {}
+
+        message = SimpleNamespace(content="positive")
+        choice = SimpleNamespace(message=message, index=0)
+
+        logprobs = [
+            [
+                TokenLogProb(token="positive", logprob=-0.1, choice_index=0),
+                TokenLogProb(token="negative", logprob=-5.0, choice_index=0),
+            ]
+        ]
+
+        completion = SimpleNamespace(choices=[choice])
+        return completion, logprobs
+
+
+class _DummyNoLogprobClient(ModelClient):
+    """ModelClient stub that does not implement logprob support."""
+
+    def convert_inputs_to_api_kwargs(self, input=None, model_kwargs=None, model_type=None):
+        return {"input": input, "model_kwargs": model_kwargs, "model_type": model_type}
+
+    def call(self, api_kwargs=None, model_type=ModelType.UNDEFINED):
+        return "fallback"
+
+    def parse_chat_completion(self, completion):
+        return GeneratorOutput(data="fallback")
+
+    def track_completion_usage(self, *args, **kwargs):
+        return None
+
+    def list_models(self):
+        return []
+
+    def call_with_logprobs(self, *args, **kwargs):
+        raise NotImplementedError
+
+
+class LogprobIntegrationTests(unittest.TestCase):
+    def test_select_from_options_uses_logprob_path_when_available(self):
+        client = _DummyLogprobClient()
+        generator = Generator(
+            model_client=client,
+            model_kwargs={"model": "fake-model", "temperature": 0.2},
+            template="Classify sentiment: {{input_str}}",
+        )
+
+        with patch.object(generator, "call", wraps=generator.call) as call_mock:
+            result = generator.select_from_options(
+                options=["positive", "negative"],
+                prompt_kwargs={"input_str": "I love this!"},
+                model_kwargs={"max_tokens": 42},
+            )
+
+        self.assertEqual("positive", result)
+        call_mock.assert_not_called()
+
+        # The logprob client should see the merged prompt and kwargs.
+        self.assertIn("positive", client.last_input)
+        self.assertEqual(client.last_model_kwargs["model"], "fake-model")
+        self.assertEqual(client.last_model_kwargs["temperature"], 0.2)
+        self.assertEqual(client.last_model_kwargs["max_tokens"], 42)
+
+        # Template must be restored after the call.
+        self.assertEqual(generator.template, "Classify sentiment: {{input_str}}")
+
+    def test_select_from_options_falls_back_when_logprob_missing(self):
+        client = _DummyNoLogprobClient()
+        generator = Generator(model_client=client, template="Pick: {{input_str}}")
+
+        with patch.object(
+            generator,
+            "call",
+            return_value=GeneratorOutput(data="negative"),
+        ) as call_mock:
+            result = generator.select_from_options(
+                options=["positive", "negative"],
+                prompt_kwargs={"input_str": "Feels off"},
+            )
+
+        self.assertEqual("negative", result)
+        call_mock.assert_called_once()
+
+
+class OpenAIClientLogprobTests(unittest.TestCase):
+    def test_extract_logprobs_attaches_choice_index(self):
+        client = OpenAIClient.__new__(OpenAIClient)
+
+        token_entries = [
+            SimpleNamespace(token="positive", logprob=-0.1),
+            SimpleNamespace(token="!", logprob=-0.05),
+        ]
+        choice = SimpleNamespace(
+            index=3,
+            logprobs=SimpleNamespace(content=token_entries),
+        )
+        completion = SimpleNamespace(choices=[choice])
+
+        extracted = client._extract_logprobs(completion)
+        self.assertEqual(1, len(extracted))
+        self.assertEqual(2, len(extracted[0]))
+        self.assertTrue(all(token.choice_index == 3 for token in extracted[0]))