74 changes: 73 additions & 1 deletion adalflow/adalflow/components/model_client/openai_client.py
@@ -24,7 +24,7 @@

# optional import
from adalflow.utils.lazy_import import safe_import, OptionalPackages

from adalflow.core.types import TokenLogProb

openai = safe_import(OptionalPackages.OPENAI.value[0], OptionalPackages.OPENAI.value[1])

@@ -1117,6 +1117,78 @@ async def acall(
else:
raise ValueError(f"model_type {model_type} is not supported")

def _extract_logprobs(self, completion: Any) -> List[List["TokenLogProb"]]:
"""Extract logprobs from OpenAI completion response.

Args:
completion: OpenAI completion response

Returns:
List of token logprobs for each choice
"""
logprobs = []

try:
if hasattr(completion, "choices"):
for choice in completion.choices:
if hasattr(choice, "logprobs") and choice.logprobs:
choice_logprobs = []
for token_logprob in choice.logprobs.content:
choice_logprobs.append(
TokenLogProb(
token=token_logprob.token,
logprob=token_logprob.logprob,
choice_index=getattr(choice, "index", None),
)
)
logprobs.append(choice_logprobs)
else:
logprobs.append([])
except Exception as e:
log.error(f"Failed to extract logprobs: {e}")
raise e

return logprobs

def call_with_logprobs(
self,
input: str = "",
model_kwargs: Dict = {},
model_type: ModelType = ModelType.UNDEFINED
) -> tuple[Any, List[List[TokenLogProb]]]:
"""Call the API with logprobs enabled for constrained generation.

        This method uses the chat.completions API, which supports logprobs,
        rather than the newer Responses API, which does not.

Args:
input: The input text to process
model_kwargs: Model parameters
model_type: Type of model call

Returns:
Tuple of (completion, logprobs) where logprobs is a list of token logprobs
"""

log.debug("call with logprobs using chat.completions")

messages = [{"role": "user", "content": str(input)}]

chat_kwargs = {
"model": model_kwargs.get("model", "gpt-3.5-turbo"),
"messages": messages,
"logprobs": True,
"top_logprobs": 5,
"temperature": model_kwargs.get("temperature", 0.1),
"max_tokens": model_kwargs.get("max_tokens", 1000),
}

completion = self.sync_client.chat.completions.create(**chat_kwargs)

logprobs = self._extract_logprobs(completion)

return completion, logprobs

@classmethod
def from_dict(cls: type[T], data: Dict[str, Any]) -> T:
obj = super().from_dict(data)
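A minimal usage sketch for the logprobs path added above, assuming `OPENAI_API_KEY` is set and that `OpenAIClient` is importable from `adalflow.components.model_client`; the model name and kwargs are illustrative, not part of this diff:

from adalflow.components.model_client import OpenAIClient

# Assumes OPENAI_API_KEY is set in the environment.
client = OpenAIClient()

# call_with_logprobs goes through chat.completions (not the Responses API),
# because that endpoint returns per-token logprobs.
completion, logprobs = client.call_with_logprobs(
    input="Classify the sentiment of: 'I loved this movie.'",
    model_kwargs={"model": "gpt-4o-mini", "temperature": 0.0, "max_tokens": 5},
)

# logprobs is a List[List[TokenLogProb]]: one inner list per choice.
for choice_tokens in logprobs:
    for tok in choice_tokens:
        print(tok.token, tok.logprob)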
3 changes: 3 additions & 0 deletions adalflow/adalflow/components/output_parsers/__init__.py
@@ -7,6 +7,7 @@
LIST_OUTPUT_FORMAT,
)
from .dataclass_parser import DataClassParser
from .constrained_parser import ConstrainedSelectionParser, MultiChoiceParser

__all__ = [
"YamlOutputParser",
@@ -16,4 +17,6 @@
"JSON_OUTPUT_FORMAT",
"LIST_OUTPUT_FORMAT",
"DataClassParser",
"ConstrainedSelectionParser",
"MultiChoiceParser",
]
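With these exports in place, the new parsers are importable from the subpackage root (a small sketch of the intended import path):

from adalflow.components.output_parsers import (
    ConstrainedSelectionParser,
    MultiChoiceParser,
)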
235 changes: 235 additions & 0 deletions adalflow/adalflow/components/output_parsers/constrained_parser.py
@@ -0,0 +1,235 @@
"""Constrained Selection Parser for forcing models to select from specific options.

This parser uses logprobs to force models to select from a predefined set of options,
similar to the Guidance library's constrained generation capabilities.
"""

import logging
from typing import List, Dict, Any, Optional, Union

from adalflow.core.component import DataComponent
from adalflow.core.types import TokenLogProb

log = logging.getLogger(__name__)


class ConstrainedSelectionParser(DataComponent):
"""Parser that forces model to select from a constrained set of options using logprobs.

This parser is particularly useful for classification tasks where you want to ensure
the model only outputs one of the predefined options.

Example:
parser = ConstrainedSelectionParser(
options=["positive", "negative", "neutral"],
allow_partial_match=True
)

# Use with Generator
generator = Generator(
model_client=model_client,
output_processors=parser,
# ... other args
)
"""

def __init__(
self,
options: List[str],
allow_partial_match: bool = True,
case_sensitive: bool = False,
max_tokens: int = 10,
temperature: float = 0.0,
):
"""Initialize the constrained selection parser.

Args:
options: List of valid options the model can select from
allow_partial_match: Whether to allow partial matches (e.g., "pos" matches "positive")
case_sensitive: Whether option matching is case sensitive
max_tokens: Maximum number of tokens to consider for selection
temperature: Temperature for logprob-based selection (0.0 = deterministic)
"""
super().__init__()

        if not options:
raise ValueError("Options list cannot be empty")

self.options = options
self.allow_partial_match = allow_partial_match
self.case_sensitive = case_sensitive
self.max_tokens = max_tokens
self.temperature = temperature

# Normalize options for matching
if not case_sensitive:
self.normalized_options = {opt.lower(): opt for opt in options}
else:
self.normalized_options = {opt: opt for opt in options}

def _normalize_text(self, text: str) -> str:
"""Normalize text for matching."""
return text if self.case_sensitive else text.lower()

def _find_best_match(self, text: str) -> Optional[str]:
"""Find the best matching option for the given text."""
normalized_text = self._normalize_text(text.strip())

# Exact match first
if normalized_text in self.normalized_options:
return self.normalized_options[normalized_text]

if not self.allow_partial_match:
return None

# Partial match
for normalized_option, original_option in self.normalized_options.items():
if normalized_option in normalized_text or normalized_text in normalized_option:
return original_option

return None

def _select_from_logprobs(self, logprobs: List[List[TokenLogProb]]) -> Optional[str]:
"""Select the best option using logprobs."""
        if not logprobs:
return None

# Flatten all logprobs
all_tokens = []
for token_list in logprobs:
all_tokens.extend(token_list)

if not all_tokens:
return None

# Calculate scores for each option
option_scores = {}

for option in self.options:
normalized_option = self._normalize_text(option)
option_words = normalized_option.split()

# score based on token probabilities
score = 0.0
matched_tokens = 0

            for token in all_tokens[:self.max_tokens]:
token_text = self._normalize_text(token.token)

for word in option_words:
if word in token_text or token_text in word:
# Use logprob as score (higher is better)
score += token.logprob
matched_tokens += 1
break

if matched_tokens > 0:
# normalize score by number of matched tokens
option_scores[option] = score / matched_tokens

if not option_scores:
return None

best_option = max(option_scores.items(), key=lambda x: x[1])[0]
return best_option

def call(self, input_data: Union[str, Dict[str, Any]]) -> str:
"""Parse input and return the best matching option.

Args:
input_data: Either a string response or a dict with 'response' and 'logprobs' keys

Returns:
The best matching option from the predefined list
"""
if isinstance(input_data, str):
# Simple text-based matching
result = self._find_best_match(input_data)
if result is None:
log.warning(f"No matching option found for: {input_data}")
return self.options[0] # Return first option as fallback
return result

elif isinstance(input_data, dict):
# Use logprobs if available
if 'logprobs' in input_data and input_data['logprobs']:
result = self._select_from_logprobs(input_data['logprobs'])
if result is not None:
return result

# Fallback to text matching
response_text = input_data.get('response', '')
if response_text:
result = self._find_best_match(response_text)
if result is not None:
return result

log.warning(f"No matching option found in: {input_data}")
return self.options[0]

else:
log.warning(f"Unsupported input type: {type(input_data)}")
return self.options[0]

def get_format_instructions(self) -> str:
"""Get format instructions for the prompt."""
options_str = ", ".join([f'"{opt}"' for opt in self.options])

return f"""You must respond with exactly one of these options: {options_str}

Rules:
- Choose the option that best matches your analysis
- Respond with only the option text, no additional explanation
- If unsure, choose the most appropriate option from the list
- Your response must be one of: {options_str}"""

def _extra_repr(self) -> str:
return f"options={self.options}, allow_partial_match={self.allow_partial_match}, case_sensitive={self.case_sensitive}"


class MultiChoiceParser(ConstrainedSelectionParser):
"""Parser for multiple choice questions with options A, B, C, D, etc."""

def __init__(
self,
num_choices: int = 4,
choice_format: str = "letter", # "letter" or "number"
**kwargs
):
"""Initialize multi-choice parser.

Args:
num_choices: Number of choices (default 4 for A, B, C, D)
choice_format: Format of choices - "letter" (A, B, C, D) or "number" (1, 2, 3, 4)
"""
if choice_format == "letter":
options = [chr(65 + i) for i in range(num_choices)] # A, B, C, D, ...
elif choice_format == "number":
options = [str(i + 1) for i in range(num_choices)] # 1, 2, 3, 4, ...
else:
raise ValueError("choice_format must be 'letter' or 'number'")

super().__init__(options=options, **kwargs)
self.num_choices = num_choices
self.choice_format = choice_format

def get_format_instructions(self) -> str:
"""Get format instructions for multiple choice."""
if self.choice_format == "letter":
choices = [chr(65 + i) for i in range(self.num_choices)]
return f"""You must respond with exactly one letter: {', '.join(choices)}

Rules:
- Choose the letter that corresponds to the best answer
- Respond with only the letter (A, B, C, D, etc.)
- No additional text or explanation
- Your response must be one of: {', '.join(choices)}"""
else:
choices = [str(i + 1) for i in range(self.num_choices)]
return f"""You must respond with exactly one number: {', '.join(choices)}

Rules:
- Choose the number that corresponds to the best answer
- Respond with only the number (1, 2, 3, 4, etc.)
- No additional text or explanation
- Your response must be one of: {', '.join(choices)}"""
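A rough end-to-end sketch of how the parsers could be wired into a Generator, following the pattern in the class docstring; the Generator import path, model name, and prompt handling are assumptions for illustration, not verified against this PR:

from adalflow.core.generator import Generator
from adalflow.components.model_client import OpenAIClient
from adalflow.components.output_parsers import (
    ConstrainedSelectionParser,
    MultiChoiceParser,
)

# Sentiment classification constrained to three labels.
sentiment_parser = ConstrainedSelectionParser(
    options=["positive", "negative", "neutral"],
    allow_partial_match=True,
)

generator = Generator(
    model_client=OpenAIClient(),
    model_kwargs={"model": "gpt-4o-mini", "temperature": 0.0},
    output_processors=sentiment_parser,
)

# The format instructions can be appended to the prompt so the model knows
# it must answer with exactly one of the allowed options.
print(sentiment_parser.get_format_instructions())

# Multiple-choice variant: options become the letters A-D.
mc_parser = MultiChoiceParser(num_choices=4, choice_format="letter")
print(mc_parser.call("B"))  # exact (case-insensitive) match -> "B"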