Skip to content

Commit f810b59

Browse files
committed
Fix PostgreSQL IP address serialization error
1 parent 97852fb commit f810b59

File tree

2 files changed

+263
-0
lines changed

2 files changed

+263
-0
lines changed

app/connectors_service/connectors/sources/postgresql/datasource.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,15 @@
33
# or more contributor license agreements. Licensed under the Elastic License 2.0;
44
# you may not use this file except in compliance with the Elastic License 2.0.
55
#
6+
from ipaddress import (
7+
IPv4Address,
8+
IPv4Interface,
9+
IPv4Network,
10+
IPv6Address,
11+
IPv6Interface,
12+
IPv6Network,
13+
)
14+
615
from asyncpg.exceptions._base import InternalClientError
716
from connectors_sdk.source import BaseDataSource
817
from connectors_sdk.utils import iso_utc
@@ -141,6 +150,40 @@ async def ping(self):
141150
msg = f"Can't connect to Postgresql on {self.postgresql_client.host}."
142151
raise Exception(msg) from e
143152

153+
def serialize(self, doc):
    """Stringify PostgreSQL network values in ``doc`` before base serialization.

    Every top-level value of ``doc`` is replaced in place by its converted
    form, and the resulting dictionary is then passed to the parent class'
    ``serialize``.

    Args:
        doc (Dict): Dictionary to be serialized

    Returns:
        doc (Dict): Whatever the parent ``serialize`` returns for the converted dictionary
    """
    network_types = (
        IPv4Address,
        IPv6Address,
        IPv4Interface,
        IPv6Interface,
        IPv4Network,
        IPv6Network,
    )

    def _serialize(value):
        """Return ``value`` with any ``ipaddress`` objects turned into strings.

        Lists and tuples become lists, dictionaries become new dictionaries,
        and both are converted recursively. Any other value is returned
        unchanged.

        Args:
            value (Any): Value to be serialized

        Returns:
            value (Any): Serialized version of input value.
        """
        if isinstance(value, network_types):
            return str(value)
        if isinstance(value, (list, tuple)):
            return [_serialize(element) for element in value]
        if isinstance(value, dict):
            return {name: _serialize(inner) for name, inner in value.items()}
        return value

    # Only existing keys are reassigned, so the dict never changes size
    # while it is being iterated.
    for field in doc:
        doc[field] = _serialize(doc[field])

    return super().serialize(doc)
186+
144187
def row2doc(self, row, doc_id, table, timestamp):
145188
row.update(
146189
{

app/connectors_service/tests/sources/test_postgresql.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,16 @@
77

88
import ssl
99
from contextlib import asynccontextmanager
10+
from datetime import datetime
11+
from decimal import Decimal
12+
from ipaddress import (
13+
IPv4Address,
14+
IPv4Interface,
15+
IPv4Network,
16+
IPv6Address,
17+
IPv6Interface,
18+
IPv6Network,
19+
)
1020
from unittest.mock import ANY, Mock, patch
1121

1222
import pytest
@@ -629,3 +639,213 @@ async def test_get_table_primary_key():
629639

630640
# Assert
631641
assert primary_keys == ["ids"]
642+
643+
644+
@pytest.mark.asyncio
async def test_serialize_ipv4_address():
    """serialize() renders an IPv4Address field as its plain string form."""
    async with create_postgresql_source() as source:
        # Setup
        record = {
            "id": 1,
            "name": "test",
            "ip_address": IPv4Address("192.168.1.1"),
        }

        # Execute
        result = source.serialize(record)

        # Assert
        rendered = result["ip_address"]
        assert rendered == "192.168.1.1"
        assert isinstance(rendered, str)
661+
662+
663+
@pytest.mark.asyncio
async def test_serialize_ipv6_address():
    """serialize() renders an IPv6Address field as its plain string form."""
    async with create_postgresql_source() as source:
        # Setup
        record = {
            "id": 1,
            "name": "test",
            "ip_address": IPv6Address("2001:db8::1"),
        }

        # Execute
        result = source.serialize(record)

        # Assert
        rendered = result["ip_address"]
        assert rendered == "2001:db8::1"
        assert isinstance(rendered, str)
680+
681+
682+
@pytest.mark.asyncio
async def test_serialize_ip_address_in_list():
    """serialize() converts every IP address held inside a list value."""
    async with create_postgresql_source() as source:
        # Setup
        addresses = [
            IPv4Address("10.0.0.1"),
            IPv4Address("10.0.0.2"),
            IPv6Address("fe80::1"),
        ]
        record = {"id": 1, "ip_addresses": addresses}

        # Execute
        result = source.serialize(record)

        # Assert
        rendered = result["ip_addresses"]
        assert rendered == ["10.0.0.1", "10.0.0.2", "fe80::1"]
        assert all(isinstance(entry, str) for entry in rendered)
702+
703+
704+
@pytest.mark.asyncio
async def test_serialize_ip_address_in_nested_dict():
    """serialize() converts IP addresses stored inside a nested dictionary."""
    async with create_postgresql_source() as source:
        # Setup
        record = {
            "id": 1,
            "connection": {
                "source_ip": IPv4Address("192.168.1.100"),
                "dest_ip": IPv4Address("8.8.8.8"),
            },
        }

        # Execute
        result = source.serialize(record)

        # Assert
        connection = result["connection"]
        assert connection["source_ip"] == "192.168.1.100"
        assert connection["dest_ip"] == "8.8.8.8"
        assert isinstance(connection["source_ip"], str)
        assert isinstance(connection["dest_ip"], str)
725+
726+
727+
@pytest.mark.asyncio
async def test_serialize_mixed_types_with_ip():
    """serialize() handles an IP address alongside datetime, Decimal and bytes values."""
    async with create_postgresql_source() as source:
        # Setup
        moment = datetime(2023, 1, 15, 10, 30, 45)
        record = {
            "id": 1,
            "ip_address": IPv4Address("10.30.0.9"),
            "timestamp": moment,
            "price": Decimal("99.99"),
            "data": "test".encode("utf-8"),
        }

        # Execute
        result = source.serialize(record)

        # Assert
        assert result["ip_address"] == "10.30.0.9"
        assert isinstance(result["ip_address"], str)
        assert result["timestamp"] == moment.isoformat()
        assert result["price"] == 99.99
        assert isinstance(result["price"], float)
        assert result["data"] == "test"
        assert isinstance(result["data"], str)
752+
753+
754+
@pytest.mark.asyncio
async def test_serialize_ipv4_network():
    """serialize() renders an IPv4Network (CIDR type) field as a string."""
    async with create_postgresql_source() as source:
        # Setup
        record = {"id": 1, "network": IPv4Network("192.168.0.0/24")}

        # Execute
        result = source.serialize(record)

        # Assert
        rendered = result["network"]
        assert rendered == "192.168.0.0/24"
        assert isinstance(rendered, str)
770+
771+
772+
@pytest.mark.asyncio
async def test_serialize_ipv6_network():
    """serialize() renders an IPv6Network (CIDR type) field as a string."""
    async with create_postgresql_source() as source:
        # Setup
        record = {"id": 1, "network": IPv6Network("2001:db8::/32")}

        # Execute
        result = source.serialize(record)

        # Assert
        rendered = result["network"]
        assert rendered == "2001:db8::/32"
        assert isinstance(rendered, str)
788+
789+
790+
@pytest.mark.asyncio
async def test_serialize_ipv4_interface():
    """serialize() renders an IPv4Interface (INET type) field as a string."""
    async with create_postgresql_source() as source:
        # Setup
        record = {"id": 1, "interface": IPv4Interface("192.168.1.1/24")}

        # Execute
        result = source.serialize(record)

        # Assert
        rendered = result["interface"]
        assert rendered == "192.168.1.1/24"
        assert isinstance(rendered, str)
806+
807+
808+
@pytest.mark.asyncio
async def test_serialize_ipv6_interface():
    """serialize() renders an IPv6Interface (INET type) field as a string."""
    async with create_postgresql_source() as source:
        # Setup
        record = {"id": 1, "interface": IPv6Interface("2001:db8::1/64")}

        # Execute
        result = source.serialize(record)

        # Assert
        rendered = result["interface"]
        assert rendered == "2001:db8::1/64"
        assert isinstance(rendered, str)
824+
825+
826+
@pytest.mark.asyncio
async def test_serialize_all_network_types():
    """serialize() stringifies address, interface and network values in one document."""
    async with create_postgresql_source() as source:
        # Setup: each field maps to (input value, expected serialized string).
        cases = {
            "ipv4_addr": (IPv4Address("192.168.1.1"), "192.168.1.1"),
            "ipv6_addr": (IPv6Address("2001:db8::1"), "2001:db8::1"),
            "ipv4_iface": (IPv4Interface("10.0.0.1/8"), "10.0.0.1/8"),
            "ipv6_iface": (IPv6Interface("fe80::1/64"), "fe80::1/64"),
            "ipv4_net": (IPv4Network("172.16.0.0/12"), "172.16.0.0/12"),
            "ipv6_net": (IPv6Network("2001:db8::/32"), "2001:db8::/32"),
        }
        record = {"id": 1}
        for field, (raw, _) in cases.items():
            record[field] = raw

        # Execute
        result = source.serialize(record)

        # Assert
        for field, (_, expected) in cases.items():
            assert result[field] == expected
        non_id_values = [value for field, value in result.items() if field != "id"]
        assert all(isinstance(value, str) for value in non_id_values)

0 commit comments

Comments
 (0)