Skip to content

Commit b8d2904

Browse files
nathan-gageclaudeDouweM
authored
Make ModelRetry hashable (#3394)
Co-authored-by: Claude <[email protected]> Co-authored-by: Douwe Maan <[email protected]>
1 parent a980aa0 commit b8d2904

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

pydantic_ai_slim/pydantic_ai/exceptions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ def __init__(self, message: str):
4444
def __eq__(self, other: Any) -> bool:
4545
return isinstance(other, self.__class__) and other.message == self.message
4646

47+
def __hash__(self) -> int:
48+
return hash((self.__class__, self.message))
49+
4750
@classmethod
4851
def __get_pydantic_core_schema__(cls, _: Any, __: Any) -> core_schema.CoreSchema:
4952
"""Pydantic core schema to allow `ModelRetry` to be (de)serialized."""

tests/test_exceptions.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""Tests for exception classes."""
2+
3+
from collections.abc import Callable
4+
from typing import Any
5+
6+
import pytest
7+
8+
from pydantic_ai import ModelRetry
9+
from pydantic_ai.exceptions import (
10+
AgentRunError,
11+
ApprovalRequired,
12+
CallDeferred,
13+
IncompleteToolCall,
14+
ModelHTTPError,
15+
UnexpectedModelBehavior,
16+
UsageLimitExceeded,
17+
UserError,
18+
)
19+
20+
21+
@pytest.mark.parametrize(
22+
'exc_factory',
23+
[
24+
lambda: ModelRetry('test'),
25+
lambda: CallDeferred(),
26+
lambda: ApprovalRequired(),
27+
lambda: UserError('test'),
28+
lambda: AgentRunError('test'),
29+
lambda: UnexpectedModelBehavior('test'),
30+
lambda: UsageLimitExceeded('test'),
31+
lambda: ModelHTTPError(500, 'model'),
32+
lambda: IncompleteToolCall('test'),
33+
],
34+
ids=[
35+
'ModelRetry',
36+
'CallDeferred',
37+
'ApprovalRequired',
38+
'UserError',
39+
'AgentRunError',
40+
'UnexpectedModelBehavior',
41+
'UsageLimitExceeded',
42+
'ModelHTTPError',
43+
'IncompleteToolCall',
44+
],
45+
)
46+
def test_exceptions_hashable(exc_factory: Callable[[], Any]):
47+
"""Test that all exception classes are hashable and usable as keys."""
48+
exc = exc_factory()
49+
50+
# Does not raise TypeError
51+
_ = hash(exc)
52+
53+
# Can be used in sets and dicts
54+
s = {exc}
55+
d = {exc: 'value'}
56+
57+
assert exc in s
58+
assert d[exc] == 'value'

0 commit comments

Comments
 (0)