Skip to content
2 changes: 1 addition & 1 deletion app/connectors_service/NOTICE.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7566,7 +7566,7 @@ made under the terms of *both* these licenses.


soupsieve
2.8
2.8.1
MIT License
MIT License

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@
# or more contributor license agreements. Licensed under the Elastic License 2.0;
# you may not use this file except in compliance with the Elastic License 2.0.
#
from ipaddress import (
IPv4Address,
IPv4Interface,
IPv4Network,
IPv6Address,
IPv6Interface,
IPv6Network,
)

from asyncpg.exceptions._base import InternalClientError
from connectors_sdk.source import BaseDataSource
from connectors_sdk.utils import iso_utc
Expand Down Expand Up @@ -141,6 +150,47 @@ async def ping(self):
msg = f"Can't connect to Postgresql on {self.postgresql_client.host}."
raise Exception(msg) from e

def serialize(self, doc):
    """Stringify ``ipaddress`` values in a document, then apply the base serialization.

    Address, interface and network objects from the ``ipaddress`` module are
    replaced by their ``str()`` form. Lists and tuples are walked recursively
    and come back as lists; nested dictionaries are walked recursively and
    come back as new dictionaries. Top-level values are replaced in ``doc``
    itself before it is handed to the parent class' ``serialize``.

    Args:
        doc (Dict): Dictionary to be serialized

    Returns:
        doc (Dict): Whatever the parent ``serialize`` returns for the updated dictionary
    """
    ip_types = (
        IPv4Address,
        IPv6Address,
        IPv4Interface,
        IPv6Interface,
        IPv4Network,
        IPv6Network,
    )

    def _to_plain(item):
        """Return ``item`` with every ``ipaddress`` object turned into a string.

        Args:
            item (Any): Value to be converted

        Returns:
            item (Any): Converted value; unrecognised types are returned untouched.
        """
        if isinstance(item, ip_types):
            return str(item)
        if isinstance(item, (list, tuple)):
            # Tuples are intentionally flattened to lists, same as lists.
            return [_to_plain(element) for element in item]
        if isinstance(item, dict):
            return {name: _to_plain(inner) for name, inner in item.items()}
        return item

    # Only existing keys are reassigned, so updating while iterating is safe.
    for field, raw in doc.items():
        doc[field] = _to_plain(raw)

    return super().serialize(doc)

def row2doc(self, row, doc_id, table, timestamp):
row.update(
{
Expand Down
220 changes: 220 additions & 0 deletions app/connectors_service/tests/sources/test_postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@

import ssl
from contextlib import asynccontextmanager
from datetime import datetime
from decimal import Decimal
from ipaddress import (
IPv4Address,
IPv4Interface,
IPv4Network,
IPv6Address,
IPv6Interface,
IPv6Network,
)
from unittest.mock import ANY, Mock, patch

import pytest
Expand Down Expand Up @@ -629,3 +639,213 @@ async def test_get_table_primary_key():

# Assert
assert primary_keys == ["ids"]


@pytest.mark.asyncio
async def test_serialize_ipv4_address():
    """A top-level IPv4Address value is returned by serialize() as a plain string."""
    record = {
        "id": 1,
        "name": "test",
        "ip_address": IPv4Address("192.168.1.1"),
    }

    async with create_postgresql_source() as source:
        result = source.serialize(record)

        converted = result["ip_address"]
        assert converted == "192.168.1.1"
        assert isinstance(converted, str)


@pytest.mark.asyncio
async def test_serialize_ipv6_address():
    """A top-level IPv6Address value is returned by serialize() as a plain string."""
    record = {
        "id": 1,
        "name": "test",
        "ip_address": IPv6Address("2001:db8::1"),
    }

    async with create_postgresql_source() as source:
        result = source.serialize(record)

        converted = result["ip_address"]
        assert converted == "2001:db8::1"
        assert isinstance(converted, str)


@pytest.mark.asyncio
async def test_serialize_ip_address_in_list():
    """IP address objects held inside a list value are each converted to strings."""
    addresses = [
        IPv4Address("10.0.0.1"),
        IPv4Address("10.0.0.2"),
        IPv6Address("fe80::1"),
    ]
    record = {"id": 1, "ip_addresses": addresses}

    async with create_postgresql_source() as source:
        result = source.serialize(record)

        converted = result["ip_addresses"]
        assert converted == ["10.0.0.1", "10.0.0.2", "fe80::1"]
        assert all(isinstance(entry, str) for entry in converted)


@pytest.mark.asyncio
async def test_serialize_ip_address_in_nested_dict():
    """IP address objects held inside a nested dictionary are converted to strings."""
    endpoints = {
        "source_ip": IPv4Address("192.168.1.100"),
        "dest_ip": IPv4Address("8.8.8.8"),
    }
    record = {"id": 1, "connection": endpoints}

    async with create_postgresql_source() as source:
        result = source.serialize(record)

        connection = result["connection"]
        assert connection["source_ip"] == "192.168.1.100"
        assert connection["dest_ip"] == "8.8.8.8"
        assert isinstance(connection["source_ip"], str)
        assert isinstance(connection["dest_ip"], str)


@pytest.mark.asyncio
async def test_serialize_mixed_types_with_ip():
    """IP conversion coexists with the datetime, Decimal and bytes handling of serialize()."""
    moment = datetime(2023, 1, 15, 10, 30, 45)
    record = {
        "id": 1,
        "ip_address": IPv4Address("10.30.0.9"),
        "timestamp": moment,
        "price": Decimal("99.99"),
        "data": "test".encode("utf-8"),
    }

    async with create_postgresql_source() as source:
        result = source.serialize(record)

        # IP address becomes its textual form.
        assert result["ip_address"] == "10.30.0.9"
        assert isinstance(result["ip_address"], str)

        # Datetime becomes an ISO-8601 string.
        assert result["timestamp"] == moment.isoformat()

        # Decimal becomes a float.
        assert result["price"] == 99.99
        assert isinstance(result["price"], float)

        # Bytes are decoded to text.
        assert result["data"] == "test"
        assert isinstance(result["data"], str)


@pytest.mark.asyncio
async def test_serialize_ipv4_network():
    """An IPv4Network value (CIDR type) is returned by serialize() as a plain string."""
    record = {"id": 1, "network": IPv4Network("192.168.0.0/24")}

    async with create_postgresql_source() as source:
        result = source.serialize(record)

        converted = result["network"]
        assert converted == "192.168.0.0/24"
        assert isinstance(converted, str)


@pytest.mark.asyncio
async def test_serialize_ipv6_network():
    """An IPv6Network value (CIDR type) is returned by serialize() as a plain string."""
    record = {"id": 1, "network": IPv6Network("2001:db8::/32")}

    async with create_postgresql_source() as source:
        result = source.serialize(record)

        converted = result["network"]
        assert converted == "2001:db8::/32"
        assert isinstance(converted, str)


@pytest.mark.asyncio
async def test_serialize_ipv4_interface():
    """An IPv4Interface value (INET type) is returned by serialize() as a plain string."""
    record = {"id": 1, "interface": IPv4Interface("192.168.1.1/24")}

    async with create_postgresql_source() as source:
        result = source.serialize(record)

        converted = result["interface"]
        assert converted == "192.168.1.1/24"
        assert isinstance(converted, str)


@pytest.mark.asyncio
async def test_serialize_ipv6_interface():
    """An IPv6Interface value (INET type) is returned by serialize() as a plain string."""
    record = {"id": 1, "interface": IPv6Interface("2001:db8::1/64")}

    async with create_postgresql_source() as source:
        result = source.serialize(record)

        converted = result["interface"]
        assert converted == "2001:db8::1/64"
        assert isinstance(converted, str)


@pytest.mark.asyncio
async def test_serialize_all_network_types():
    """Every ipaddress type used for PostgreSQL network columns is converted in one document."""
    record = {
        "id": 1,
        "ipv4_addr": IPv4Address("192.168.1.1"),
        "ipv6_addr": IPv6Address("2001:db8::1"),
        "ipv4_iface": IPv4Interface("10.0.0.1/8"),
        "ipv6_iface": IPv6Interface("fe80::1/64"),
        "ipv4_net": IPv4Network("172.16.0.0/12"),
        "ipv6_net": IPv6Network("2001:db8::/32"),
    }
    expected_text = {
        "ipv4_addr": "192.168.1.1",
        "ipv6_addr": "2001:db8::1",
        "ipv4_iface": "10.0.0.1/8",
        "ipv6_iface": "fe80::1/64",
        "ipv4_net": "172.16.0.0/12",
        "ipv6_net": "2001:db8::/32",
    }

    async with create_postgresql_source() as source:
        result = source.serialize(record)

        for field, text in expected_text.items():
            assert result[field] == text

        # Everything except the integer id must now be a string.
        non_id_values = [value for key, value in result.items() if key != "id"]
        assert all(isinstance(value, str) for value in non_id_values)