|
1 | 1 | import unittest |
2 | 2 | from datetime import date, datetime, timedelta |
3 | 3 | from decimal import Decimal |
| 4 | +from typing import Optional |
4 | 5 | from uuid import UUID |
5 | 6 |
|
6 | 7 | import six |
7 | 8 | from dateutil.tz import tzutc |
| 9 | +from pydantic import BaseModel |
| 10 | +from pydantic.v1 import BaseModel as BaseModelV1 |
8 | 11 |
|
9 | 12 | from posthog import utils |
10 | 13 |
|
@@ -81,6 +84,32 @@ def test_remove_slash(self): |
81 | 84 | self.assertEqual("http://posthog.io", utils.remove_trailing_slash("http://posthog.io/")) |
82 | 85 | self.assertEqual("http://posthog.io", utils.remove_trailing_slash("http://posthog.io")) |
83 | 86 |
|
| 87 | + def test_clean_pydantic(self): |
| 88 | + class ModelV2(BaseModel): |
| 89 | + foo: str |
| 90 | + bar: int |
| 91 | + baz: Optional[str] = None |
| 92 | + |
| 93 | + class ModelV1(BaseModelV1): |
| 94 | + foo: int |
| 95 | + bar: str |
| 96 | + |
| 97 | + class NestedModel(BaseModel): |
| 98 | + foo: ModelV2 |
| 99 | + |
| 100 | + self.assertEqual(utils.clean(ModelV2(foo="1", bar=2)), {"foo": "1", "bar": 2, "baz": None}) |
| 101 | + self.assertEqual(utils.clean(ModelV1(foo=1, bar="2")), {"foo": 1, "bar": "2"}) |
| 102 | + self.assertEqual( |
| 103 | + utils.clean(NestedModel(foo=ModelV2(foo="1", bar=2, baz="3"))), {"foo": {"foo": "1", "bar": 2, "baz": "3"}} |
| 104 | + ) |
| 105 | + |
| 106 | + class Dummy: |
| 107 | + def model_dump(self, required_param): |
| 108 | + pass |
| 109 | + |
| 110 | + # Skips a class with a defined non-Pydantic `model_dump` method. |
| 111 | + self.assertEqual(utils.clean({"test": Dummy()}), {}) |
| 112 | + |
84 | 113 |
|
85 | 114 | class TestSizeLimitedDict(unittest.TestCase): |
86 | 115 | def test_size_limited_dict(self): |
|
0 commit comments