|
1 | | -from itertools import cycle |
| 1 | +from typing import Any |
2 | 2 |
|
3 | 3 | from langchain_core.callbacks import ( |
4 | 4 | UsageMetadataCallbackHandler, |
|
12 | 12 | UsageMetadata, |
13 | 13 | add_usage, |
14 | 14 | ) |
| 15 | +from langchain_core.outputs import ChatResult |
15 | 16 |
|
16 | 17 | usage1 = UsageMetadata( |
17 | 18 | input_tokens=1, |
|
45 | 46 | ] |
46 | 47 |
|
47 | 48 |
|
| 49 | +class FakeChatModelWithResponseMetadata(GenericFakeChatModel): |
| 50 | + model_name: str |
| 51 | + |
| 52 | + def _generate(self, *args: Any, **kwargs: Any) -> ChatResult: |
| 53 | + result = super()._generate(*args, **kwargs) |
| 54 | + result.generations[0].message.response_metadata = { |
| 55 | + "model_name": self.model_name |
| 56 | + } |
| 57 | + return result |
| 58 | + |
| 59 | + |
48 | 60 | def test_usage_callback() -> None: |
49 | | - llm = GenericFakeChatModel(messages=cycle(messages)) |
| 61 | + llm = FakeChatModelWithResponseMetadata( |
| 62 | + messages=iter(messages), model_name="test_model" |
| 63 | + ) |
50 | 64 |
|
51 | 65 | # Test context manager |
52 | 66 | with get_usage_metadata_callback() as cb: |
53 | 67 | _ = llm.invoke("Message 1") |
54 | 68 | _ = llm.invoke("Message 2") |
55 | 69 | total_1_2 = add_usage(usage1, usage2) |
56 | | - assert cb.usage_metadata == total_1_2 |
| 70 | + assert cb.usage_metadata == {"test_model": total_1_2} |
57 | 71 | _ = llm.invoke("Message 3") |
58 | 72 | _ = llm.invoke("Message 4") |
59 | 73 | total_3_4 = add_usage(usage3, usage4) |
60 | | - assert cb.usage_metadata == add_usage(total_1_2, total_3_4) |
| 74 | + assert cb.usage_metadata == {"test_model": add_usage(total_1_2, total_3_4)} |
61 | 75 |
|
62 | 76 | # Test via config |
| 77 | + llm = FakeChatModelWithResponseMetadata( |
| 78 | + messages=iter(messages[:2]), model_name="test_model" |
| 79 | + ) |
63 | 80 | callback = UsageMetadataCallbackHandler() |
64 | 81 | _ = llm.batch(["Message 1", "Message 2"], config={"callbacks": [callback]}) |
65 | | - assert callback.usage_metadata == total_1_2 |
| 82 | + assert callback.usage_metadata == {"test_model": total_1_2} |
| 83 | + |
| 84 | + # Test multiple models |
| 85 | + llm_1 = FakeChatModelWithResponseMetadata( |
| 86 | + messages=iter(messages[:2]), model_name="test_model_1" |
| 87 | + ) |
| 88 | + llm_2 = FakeChatModelWithResponseMetadata( |
| 89 | + messages=iter(messages[2:4]), model_name="test_model_2" |
| 90 | + ) |
| 91 | + callback = UsageMetadataCallbackHandler() |
| 92 | + _ = llm_1.batch(["Message 1", "Message 2"], config={"callbacks": [callback]}) |
| 93 | + _ = llm_2.batch(["Message 3", "Message 4"], config={"callbacks": [callback]}) |
| 94 | + assert callback.usage_metadata == { |
| 95 | + "test_model_1": total_1_2, |
| 96 | + "test_model_2": total_3_4, |
| 97 | + } |
66 | 98 |
|
67 | 99 |
|
68 | 100 | async def test_usage_callback_async() -> None: |
69 | | - llm = GenericFakeChatModel(messages=cycle(messages)) |
| 101 | + llm = FakeChatModelWithResponseMetadata( |
| 102 | + messages=iter(messages), model_name="test_model" |
| 103 | + ) |
70 | 104 |
|
71 | 105 | # Test context manager |
72 | 106 | with get_usage_metadata_callback() as cb: |
73 | 107 | _ = await llm.ainvoke("Message 1") |
74 | 108 | _ = await llm.ainvoke("Message 2") |
75 | 109 | total_1_2 = add_usage(usage1, usage2) |
76 | | - assert cb.usage_metadata == total_1_2 |
| 110 | + assert cb.usage_metadata == {"test_model": total_1_2} |
77 | 111 | _ = await llm.ainvoke("Message 3") |
78 | 112 | _ = await llm.ainvoke("Message 4") |
79 | 113 | total_3_4 = add_usage(usage3, usage4) |
80 | | - assert cb.usage_metadata == add_usage(total_1_2, total_3_4) |
| 114 | + assert cb.usage_metadata == {"test_model": add_usage(total_1_2, total_3_4)} |
81 | 115 |
|
82 | 116 | # Test via config |
| 117 | + llm = FakeChatModelWithResponseMetadata( |
| 118 | + messages=iter(messages[:2]), model_name="test_model" |
| 119 | + ) |
83 | 120 | callback = UsageMetadataCallbackHandler() |
84 | 121 | _ = await llm.abatch(["Message 1", "Message 2"], config={"callbacks": [callback]}) |
85 | | - assert callback.usage_metadata == total_1_2 |
| 122 | + assert callback.usage_metadata == {"test_model": total_1_2} |
0 commit comments