|
17 | 17 |
|
18 | 18 | class SubscriptableBaseModel(BaseModel):
|
19 | 19 | def __getitem__(self, key: str) -> Any:
|
20 |
| - return getattr(self, key) |
| 20 | + """ |
| 21 | + >>> msg = Message(role='user') |
| 22 | + >>> msg['role'] |
| 23 | + 'user' |
| 24 | + >>> tool = Tool() |
| 25 | + >>> tool['type'] |
| 26 | + 'function' |
| 27 | + >>> msg = Message(role='user') |
| 28 | + >>> msg['nonexistent'] |
| 29 | + Traceback (most recent call last): |
| 30 | + KeyError: 'nonexistent' |
| 31 | + """ |
| 32 | + if key in self: |
| 33 | + return getattr(self, key) |
| 34 | + |
| 35 | + raise KeyError(key) |
21 | 36 |
|
22 | 37 | def __setitem__(self, key: str, value: Any) -> None:
|
| 38 | + """ |
| 39 | + >>> msg = Message(role='user') |
| 40 | + >>> msg['role'] = 'assistant' |
| 41 | + >>> msg['role'] |
| 42 | + 'assistant' |
| 43 | + """ |
23 | 44 | setattr(self, key, value)
|
24 | 45 |
|
25 | 46 | def __contains__(self, key: str) -> bool:
|
@@ -61,7 +82,20 @@ def __contains__(self, key: str) -> bool:
|
61 | 82 | return False
|
62 | 83 |
|
63 | 84 | def get(self, key: str, default: Any = None) -> Any:
|
64 |
| - return getattr(self, key, default) |
| 85 | + """ |
| 86 | + >>> msg = Message(role='user') |
| 87 | + >>> msg.get('role') |
| 88 | + 'user' |
| 89 | + >>> tool = Tool() |
| 90 | + >>> tool.get('type') |
| 91 | + 'function' |
| 92 | + >>> msg = Message(role='user') |
| 93 | + >>> msg.get('nonexistent') |
| 94 | + >>> msg = Message(role='user') |
| 95 | + >>> msg.get('nonexistent', 'default') |
| 96 | + 'default' |
| 97 | + """ |
| 98 | + return self[key] if key in self else default |
65 | 99 |
|
66 | 100 |
|
67 | 101 | class Options(SubscriptableBaseModel):
|
|
0 commit comments