Skip to content

Commit 730c850

Browse files
committed
Support serializing dataclasses
1 parent 5a4167d commit 730c850

File tree

2 files changed

+67
-6
lines changed

2 files changed

+67
-6
lines changed

posthog/test/test_utils.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest
2+
from dataclasses import dataclass
23
from datetime import date, datetime, timedelta
34
from decimal import Decimal
45
from typing import Optional
@@ -56,7 +57,10 @@ def test_clean(self):
5657
self.assertEqual(combined.keys(), pre_clean_keys)
5758

5859
# test UUID separately, as the UUID object doesn't equal its string representation according to Python
59-
self.assertEqual(utils.clean(UUID("12345678123456781234567812345678")), "12345678-1234-5678-1234-567812345678")
60+
self.assertEqual(
61+
utils.clean(UUID("12345678123456781234567812345678")),
62+
"12345678-1234-5678-1234-567812345678",
63+
)
6064

6165
def test_clean_with_dates(self):
6266
dict_with_dates = {
@@ -81,8 +85,12 @@ def test_clean_fn(self):
8185
self.assertEqual(cleaned["fn"], None)
8286

8387
def test_remove_slash(self):
84-
self.assertEqual("http://posthog.io", utils.remove_trailing_slash("http://posthog.io/"))
85-
self.assertEqual("http://posthog.io", utils.remove_trailing_slash("http://posthog.io"))
88+
self.assertEqual(
89+
"http://posthog.io", utils.remove_trailing_slash("http://posthog.io/")
90+
)
91+
self.assertEqual(
92+
"http://posthog.io", utils.remove_trailing_slash("http://posthog.io")
93+
)
8694

8795
def test_clean_pydantic(self):
8896
class ModelV2(BaseModel):
@@ -97,10 +105,13 @@ class ModelV1(BaseModelV1):
97105
class NestedModel(BaseModel):
98106
foo: ModelV2
99107

100-
self.assertEqual(utils.clean(ModelV2(foo="1", bar=2)), {"foo": "1", "bar": 2, "baz": None})
108+
self.assertEqual(
109+
utils.clean(ModelV2(foo="1", bar=2)), {"foo": "1", "bar": 2, "baz": None}
110+
)
101111
self.assertEqual(utils.clean(ModelV1(foo=1, bar="2")), {"foo": 1, "bar": "2"})
102112
self.assertEqual(
103-
utils.clean(NestedModel(foo=ModelV2(foo="1", bar=2, baz="3"))), {"foo": {"foo": "1", "bar": 2, "baz": "3"}}
113+
utils.clean(NestedModel(foo=ModelV2(foo="1", bar=2, baz="3"))),
114+
{"foo": {"foo": "1", "bar": 2, "baz": "3"}},
104115
)
105116

106117
class Dummy:
@@ -110,6 +121,45 @@ def model_dump(self, required_param):
110121
# Skips a class with a defined non-Pydantic `model_dump` method.
111122
self.assertEqual(utils.clean({"test": Dummy()}), {})
112123

124+
def test_clean_dataclass(self):
125+
@dataclass
126+
class InnerDataClass:
127+
inner_foo: str
128+
inner_bar: int
129+
inner_uuid: UUID
130+
inner_date: datetime
131+
132+
@dataclass
133+
class TestDataClass:
134+
foo: str
135+
bar: int
136+
nested: InnerDataClass
137+
138+
self.assertEqual(
139+
utils.clean(
140+
TestDataClass(
141+
foo="1",
142+
bar=2,
143+
nested=InnerDataClass(
144+
inner_foo="3",
145+
inner_bar=4,
146+
inner_uuid=UUID("12345678123456781234567812345678"),
147+
inner_date=datetime(2025, 1, 1),
148+
),
149+
)
150+
),
151+
{
152+
"foo": "1",
153+
"bar": 2,
154+
"nested": {
155+
"inner_foo": "3",
156+
"inner_bar": 4,
157+
"inner_uuid": "12345678-1234-5678-1234-567812345678",
158+
"inner_date": datetime(2025, 1, 1),
159+
},
160+
},
161+
)
162+
113163

114164
class TestSizeLimitedDict(unittest.TestCase):
115165
def test_size_limited_dict(self):

posthog/utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numbers
33
import re
44
from collections import defaultdict
5+
from dataclasses import asdict, is_dataclass
56
from datetime import date, datetime, timezone
67
from decimal import Decimal
78
from uuid import UUID
@@ -51,7 +52,9 @@ def clean(item):
5152
return float(item)
5253
if isinstance(item, UUID):
5354
return str(item)
54-
if isinstance(item, (six.string_types, bool, numbers.Number, datetime, date, type(None))):
55+
if isinstance(
56+
item, (six.string_types, bool, numbers.Number, datetime, date, type(None))
57+
):
5558
return item
5659
if isinstance(item, (set, list, tuple)):
5760
return _clean_list(item)
@@ -68,6 +71,8 @@ def clean(item):
6871
pass
6972
if isinstance(item, dict):
7073
return _clean_dict(item)
74+
if is_dataclass(item) and not isinstance(item, type):
75+
return _clean_dataclass(item)
7176
return _coerce_unicode(item)
7277

7378

@@ -90,6 +95,12 @@ def _clean_dict(dict_):
9095
return data
9196

9297

98+
def _clean_dataclass(dataclass_):
99+
data = asdict(dataclass_)
100+
data = _clean_dict(data)
101+
return data
102+
103+
93104
def _coerce_unicode(cmplx):
94105
try:
95106
item = cmplx.decode("utf-8", "strict")

0 commit comments

Comments
 (0)