|
| 1 | +from typing import Any, Optional, Union |
| 2 | + |
| 3 | +from evals.solvers.solver import Solver, SolverResult |
| 4 | +from evals.task_state import TaskState, Message |
| 5 | +from evals.record import record_sampling |
| 6 | +from evals.utils.api_utils import request_with_timeout |
| 7 | + |
| 8 | +import anthropic |
| 9 | +from anthropic import Anthropic |
| 10 | +from anthropic.types import ContentBlock, MessageParam, Usage |
| 11 | +import backoff |
| 12 | + |
# Map OpenAI-style chat roles onto the two roles accepted by the Anthropic
# Messages API. There is no "system" role in the Anthropic message list (the
# system prompt is passed as a separate `system=` kwarg), so "system"
# messages are demoted to "user".
oai_to_anthropic_role = dict(
    system="user",
    user="user",
    assistant="assistant",
)
| 18 | + |
| 19 | + |
class AnthropicSolver(Solver):
    """
    A solver class that uses the Anthropic API for textual chat-based tasks.
    """

    def __init__(
        self,
        model_name: str,
        max_tokens: int = 512,
        postprocessors: Optional[list[str]] = None,
        extra_options: Optional[dict] = None,
        registry: Any = None,
    ):
        """
        Args:
            model_name: Anthropic model identifier, e.g. "claude-3-opus-...".
                https://docs.anthropic.com/claude/docs/models-overview#model-comparison
            max_tokens: passed to `messages.create` (a required kwarg there).
            postprocessors: forwarded to the base Solver.
            extra_options: extra kwargs merged into each `messages.create` call.
            registry: unused here; accepted for solver-registry compatibility.
        """
        # `None` sentinels instead of mutable defaults ([] / {}), which would
        # be shared across all instances and calls.
        super().__init__(postprocessors=postprocessors if postprocessors is not None else [])
        self.model_name = model_name
        self.max_tokens = max_tokens
        self.extra_options = extra_options if extra_options is not None else {}

    def _solve(self, task_state: TaskState, **kwargs) -> SolverResult:
        """
        Solve the task using the Anthropic API.

        The task description is sent as the top-level system prompt; the
        remaining messages are converted to Anthropic's alternating
        user/assistant format.
        """
        # Copy, so that the logging-only system-message prepend below does not
        # mutate the caller's task_state.messages in place.
        orig_msgs = list(task_state.messages)
        anth_msgs = self._convert_msgs_to_anthropic_format(task_state.messages)

        # TODO: handle context length limit; possible once anthropic tokenizer is available

        # calls client.messages.create, but is wrapped with backoff retrying decorator
        response = anthropic_create_retrying(
            client=Anthropic(max_retries=0),  # we take care of retries ourselves
            model=self.model_name,
            system=task_state.task_description,
            messages=anth_msgs,
            max_tokens=self.max_tokens,  # required kwarg for messages.create
            # per-call kwargs win defaults; solver-level extra_options win both
            **{**kwargs, **self.extra_options},
        )
        solver_result = SolverResult(
            output=response.content[0].text, raw_completion_result=response.content
        )

        # for logging purposes: prepend the task desc to the msgs as a system message
        orig_msgs.insert(
            0, Message(role="system", content=task_state.task_description).to_dict()
        )
        record_sampling(
            prompt=orig_msgs,  # original message format, supported by our logviz
            sampled=[solver_result.output],
            model=self.model_name,
            usage=anth_to_openai_usage(response.usage),
        )
        return solver_result

    @property
    def name(self) -> str:
        return self.model_name

    @property
    def model_version(self) -> Union[str, dict]:
        """
        For the moment, Anthropic does not use aliases,
        so model_version is the same as model_name.
        """
        return self.model_name

    @staticmethod
    def _convert_msgs_to_anthropic_format(msgs: list[Message]) -> list[MessageParam]:
        """
        Convert evals-style messages to the Anthropic Messages API format.

        The Anthropic API requires that the message list has
        - Roles as 'user' or 'assistant'
        - Alternating 'user' and 'assistant' messages

        Note: the top-level system prompt is handled separately and should not be
        included in the messages list.
        """
        # enforce valid roles; convert to Anthropic message type
        anth_msgs = [
            MessageParam(
                role=oai_to_anthropic_role[msg.role],
                content=[ContentBlock(text=msg.content, type="text")],
            )
            for msg in msgs
        ]
        # enforce alternating roles by merging consecutive messages with the same role
        # e.g. [user1, user2, assistant1, user3] -> [user12, assistant1, user3]
        alt_msgs = []
        for msg in anth_msgs:
            if len(alt_msgs) > 0 and msg["role"] == alt_msgs[-1]["role"]:
                # Merge consecutive messages from the same role by appending
                # the content blocks to the previous message.
                alt_msgs[-1]["content"].extend(msg["content"])
            else:
                alt_msgs.append(msg)

        return alt_msgs
| 114 | + |
| 115 | + |
# Retry transient Anthropic failures (rate limits, connection/timeout errors,
# 5xx) with exponential backoff, capped at 60s between attempts.
@backoff.on_exception(
    wait_gen=backoff.expo,
    exception=(
        anthropic.RateLimitError,
        anthropic.APIConnectionError,
        anthropic.APITimeoutError,
        anthropic.InternalServerError,
    ),
    max_value=60,
    factor=1.5,
)
def anthropic_create_retrying(client: Anthropic, *args, **kwargs):
    """
    Helper function for creating a backoff-retry enabled message request.
    `args` and `kwargs` match what is accepted by `client.messages.create`.
    """
    # request_with_timeout wraps the call so a hung request raises instead of
    # blocking forever; timeouts are then retried by the decorator above.
    result = request_with_timeout(client.messages.create, *args, **kwargs)
    # NOTE(review): `messages.create` returns an SDK `Message` object, not a
    # dict, so this membership test presumably never matches — verify whether
    # this check (copied from the dict-based OpenAI path?) is still needed.
    if "error" in result:
        raise Exception(result["error"])
    return result
| 136 | + |
| 137 | + |
def anth_to_openai_usage(anth_usage: Usage) -> dict:
    """
    Convert an Anthropic `Usage` object into a dict whose keys match the
    OpenAI usage dict, so token counts are logged uniformly across providers.
    """
    # TODO: make this format of dict a dataclass type to be reused througout lib?
    prompt = anth_usage.input_tokens
    completion = anth_usage.output_tokens
    return {
        "completion_tokens": completion,
        "prompt_tokens": prompt,
        "total_tokens": prompt + completion,
    }
0 commit comments