diff --git a/connectors/sources/postgresql.py b/connectors/sources/postgresql.py index 28be15d46..03ed910d2 100644 --- a/connectors/sources/postgresql.py +++ b/connectors/sources/postgresql.py @@ -7,10 +7,29 @@ import ssl from functools import cached_property, partial +from ipaddress import ( + IPv4Address, + IPv4Interface, + IPv4Network, + IPv6Address, + IPv6Interface, + IPv6Network, +) from urllib.parse import quote +from uuid import UUID import fastjsonschema from asyncpg.exceptions._base import InternalClientError +from asyncpg.types import ( + BitString, + Box, + Circle, + Line, + LineSegment, + Path, + Point, + Polygon, +) from fastjsonschema import JsonSchemaValueException from sqlalchemy import text from sqlalchemy.exc import ProgrammingError @@ -510,6 +529,76 @@ async def ping(self): msg = f"Can't connect to Postgresql on {self.postgresql_client.host}." raise Exception(msg) from e + def serialize(self, doc): + """Override base serialize to handle PostgreSQL-specific types. + + PostgreSQL connector uses asyncpg which returns special Python objects for certain + PostgreSQL data types that need to be serialized to strings: + - Network types (INET, CIDR) -> ipaddress module objects + - UUID type -> uuid.UUID objects + - Geometric types (POINT, LINE, POLYGON, etc.) -> asyncpg.types objects + - BitString type (BIT, VARBIT) -> asyncpg.types.BitString objects + + Args: + doc (Dict): Dictionary to be serialized + + Returns: + doc (Dict): Serialized version of dictionary + """ + + def _serialize(value): + """Serialize input value with respect to its datatype. + + Args: + value (Any): Value to be serialized + + Returns: + value (Any): Serialized version of input value. + """ + match value: + case ( + IPv4Address() + | IPv6Address() + | IPv4Interface() + | IPv6Interface() + | IPv4Network() + | IPv6Network() + ): + return str(value) + case UUID(): + return str(value) + case Point(): + return f"({value.x}, {value.y})" + case LineSegment(): + return ( + f"[({value.p1.x}, {value.p1.y}), ({value.p2.x}, {value.p2.y})]" + ) + case Box(): + return f"[({value.high.x}, {value.high.y}), ({value.low.x}, {value.low.y})]" + case Polygon(): + # Polygon inherits from Path, so check it first + coords = [(p.x, p.y) for p in value.points] + return str(coords) + case Path(): + coords = [(p.x, p.y) for p in value.points] + status = "closed" if value.is_closed else "open" + return f"{status} {str(coords)}" + case Line() | Circle(): + return str(value) + case BitString(): + return value.as_string() + case list() | tuple(): + return [_serialize(item) for item in value] + case dict(): + return {k: _serialize(v) for k, v in value.items()} + case _: + return value + + for key, value in doc.items(): + doc[key] = _serialize(value) + + return super().serialize(doc) + def row2doc(self, row, doc_id, table, timestamp): row.update( { diff --git a/tests/sources/fixtures/postgresql/fixture.py b/tests/sources/fixtures/postgresql/fixture.py index ced41e8e9..994a110aa 100644 --- a/tests/sources/fixtures/postgresql/fixture.py +++ b/tests/sources/fixtures/postgresql/fixture.py @@ -9,6 +9,16 @@ import random import asyncpg +from asyncpg.types import ( + BitString, + Box, + Circle, + Line, + LineSegment, + Path, + Point, + Polygon, +) from tests.commons import WeightedFakeProvider @@ -36,9 +46,12 @@ event_loop = asyncio.get_event_loop() +# Number of test rows in special_types table for serialization testing +SPECIAL_TYPES_TEST_ROWS = 3 + def get_num_docs(): - print(NUM_TABLES * (RECORD_COUNT - RECORDS_TO_DELETE)) + print(NUM_TABLES * (RECORD_COUNT - RECORDS_TO_DELETE) + SPECIAL_TYPES_TEST_ROWS) async def load(): @@ -81,6 +94,104 @@ async def inject_lines(table, connect, lines): inserted += batch_size print(f"Inserting batch #{batch} of {batch_size} documents.") + async def create_special_types_table(): + """Create a table with PostgreSQL special types that require serialization.""" + connect = await asyncpg.connect(CONNECTION_STRING) + + print("Creating special_types table for serialization testing...") + create_table_sql = """ + CREATE TABLE IF NOT EXISTS special_types ( + id SERIAL PRIMARY KEY, + ip_inet INET, + ip_cidr CIDR, + uuid_col UUID, + point_col POINT, + line_col LINE, + lseg_col LSEG, + box_col BOX, + path_col PATH, + polygon_col POLYGON, + circle_col CIRCLE, + bit_col BIT(8), + varbit_col VARBIT(16), + inet_array INET[], + uuid_array UUID[] + ) + """ + await connect.execute(create_table_sql) + + print("Inserting special type test data...") + insert_sql = """ + INSERT INTO special_types ( + ip_inet, ip_cidr, uuid_col, + point_col, line_col, lseg_col, box_col, path_col, polygon_col, circle_col, + bit_col, varbit_col, + inet_array, uuid_array + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + """ + + test_rows = [ + ( + "192.168.1.1", + "10.0.0.0/8", + "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11", + Point(1.5, 2.5), + Line(1, -1, 0), + LineSegment((0, 0), (1, 1)), + Box((2, 2), (0, 0)), + Path((0, 0), (1, 1), (2, 0)), + Polygon((0, 0), (1, 0), (1, 1), (0, 1)), + Circle((0, 0), 5), + BitString("10101010"), + BitString("1101"), + ["192.168.1.1", "10.0.0.1"], + [ + "550e8400-e29b-41d4-a716-446655440000", + "f47ac10b-58cc-4372-a567-0e02b2c3d479", + ], + ), + ( + "2001:db8::1", + "2001:db8::/32", + "123e4567-e89b-12d3-a456-426614174000", + Point(-3.14, 2.71), + Line(2, 3, -6), + LineSegment((-1, -1), (1, 1)), + Box((10, 10), (-10, -10)), + Path((0, 0), (3, 0), (3, 3), (0, 3), is_closed=True), + Polygon((-1, -1), (1, -1), (1, 1), (-1, 1)), + Circle((5, 5), 2.5), + BitString("11110000"), + BitString("101010101010"), + ["::1", "fe80::1"], + ["aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"], + ), + ( + "10.30.0.9/24", + "172.16.0.0/12", + "00000000-0000-0000-0000-000000000000", + Point(0, 0), + Line(0, 1, 0), + LineSegment((5, 5), (10, 10)), + Box((100, 100), (50, 50)), + Path((0, 0), (5, 0), (5, 5)), + Polygon((0, 0), (4, 0), (4, 3), (0, 3)), + Circle((0, 0), 10), + BitString("00000000"), + BitString("1111111111111111"), + ["192.0.2.1", "198.51.100.1", "203.0.113.1"], + [ + "12345678-1234-5678-1234-567812345678", + "87654321-4321-8765-4321-876543218765", + ], + ), + ] + + await connect.executemany(insert_sql, test_rows) + print(f"Inserted {len(test_rows)} rows with special types") + + await connect.close() + async def load_rows(): """N tables of 10001 rows each. each row is ~ 1024*20 bytes""" connect = await asyncpg.connect(CONNECTION_STRING) @@ -93,6 +204,7 @@ async def load_rows(): await create_readonly_user() await load_rows() + await create_special_types_table() async def remove(): diff --git a/tests/sources/test_postgresql.py b/tests/sources/test_postgresql.py index db157106d..39d2eb43c 100644 --- a/tests/sources/test_postgresql.py +++ b/tests/sources/test_postgresql.py @@ -7,9 +7,30 @@ import ssl from contextlib import asynccontextmanager +from datetime import datetime +from decimal import Decimal +from ipaddress import ( + IPv4Address, + IPv4Interface, + IPv4Network, + IPv6Address, + IPv6Interface, + IPv6Network, +) from unittest.mock import ANY, Mock, patch +from uuid import UUID import pytest +from asyncpg.types import ( + BitString, + Box, + Circle, + Line, + LineSegment, + Path, + Point, + Polygon, +) from freezegun import freeze_time from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.asyncio.engine import AsyncEngine @@ -630,3 +651,452 @@ async def test_get_table_primary_key(): # Assert assert primary_keys == ["ids"] + + +@pytest.mark.asyncio +async def test_serialize_ipv4_address(): + """Test that IPv4Address objects are correctly serialized to strings""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "name": "test", + "ip_address": IPv4Address("192.168.1.1"), + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["ip_address"] == "192.168.1.1" + + +@pytest.mark.asyncio +async def test_serialize_ipv6_address(): + """Test that IPv6Address objects are correctly serialized to strings""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "name": "test", + "ip_address": IPv6Address("2001:db8::1"), + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["ip_address"] == "2001:db8::1" + + +@pytest.mark.asyncio +async def test_serialize_ip_address_in_list(): + """Test that IP addresses nested in lists are correctly serialized""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "ip_addresses": [ + IPv4Address("10.0.0.1"), + IPv4Address("10.0.0.2"), + IPv6Address("fe80::1"), + ], + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["ip_addresses"] == ["10.0.0.1", "10.0.0.2", "fe80::1"] + assert all(isinstance(ip, str) for ip in serialized["ip_addresses"]) + + +@pytest.mark.asyncio +async def test_serialize_ip_address_in_nested_dict(): + """Test that IP addresses nested in dictionaries are correctly serialized""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "connection": { + "source_ip": IPv4Address("192.168.1.100"), + "dest_ip": IPv4Address("8.8.8.8"), + }, + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["connection"]["source_ip"] == "192.168.1.100" + assert serialized["connection"]["dest_ip"] == "8.8.8.8" + + +@pytest.mark.asyncio +async def test_serialize_mixed_types_with_ip(): + """Test that serialization handles mixed data types including IP addresses""" + # Setup + async with create_postgresql_source() as source: + timestamp = datetime(2023, 1, 15, 10, 30, 45) + doc = { + "id": 1, + "ip_address": IPv4Address("10.30.0.9"), + "timestamp": timestamp, + "price": Decimal("99.99"), + "data": bytes("test", "utf-8"), + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["ip_address"] == "10.30.0.9" + assert serialized["timestamp"] == timestamp.isoformat() + assert serialized["price"] == 99.99 + assert serialized["data"] == "test" + + +@pytest.mark.asyncio +async def test_serialize_ipv4_network(): + """Test that IPv4Network objects (CIDR type) are correctly serialized to strings""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "network": IPv4Network("192.168.0.0/24"), + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["network"] == "192.168.0.0/24" + + +@pytest.mark.asyncio +async def test_serialize_ipv6_network(): + """Test that IPv6Network objects (CIDR type) are correctly serialized to strings""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "network": IPv6Network("2001:db8::/32"), + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["network"] == "2001:db8::/32" + + +@pytest.mark.asyncio +async def test_serialize_ipv4_interface(): + """Test that IPv4Interface objects (INET type) are correctly serialized to strings""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "interface": IPv4Interface("192.168.1.1/24"), + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["interface"] == "192.168.1.1/24" + + +@pytest.mark.asyncio +async def test_serialize_ipv6_interface(): + """Test that IPv6Interface objects (INET type) are correctly serialized to strings""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "interface": IPv6Interface("2001:db8::1/64"), + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["interface"] == "2001:db8::1/64" + + +@pytest.mark.asyncio +async def test_serialize_all_network_types(): + """Test that all PostgreSQL network types are correctly serialized""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "ipv4_addr": IPv4Address("192.168.1.1"), + "ipv6_addr": IPv6Address("2001:db8::1"), + "ipv4_iface": IPv4Interface("10.0.0.1/8"), + "ipv6_iface": IPv6Interface("fe80::1/64"), + "ipv4_net": IPv4Network("172.16.0.0/12"), + "ipv6_net": IPv6Network("2001:db8::/32"), + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["ipv4_addr"] == "192.168.1.1" + assert serialized["ipv6_addr"] == "2001:db8::1" + assert serialized["ipv4_iface"] == "10.0.0.1/8" + assert serialized["ipv6_iface"] == "fe80::1/64" + assert serialized["ipv4_net"] == "172.16.0.0/12" + assert serialized["ipv6_net"] == "2001:db8::/32" + assert all(isinstance(v, str) for k, v in serialized.items() if k != "id") + + +@pytest.mark.asyncio +async def test_serialize_uuid(): + """Test that UUID objects are correctly serialized to strings""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "uuid_col": UUID("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["uuid_col"] == "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11" + + +@pytest.mark.asyncio +async def test_serialize_uuid_in_list(): + """Test that UUIDs nested in lists are correctly serialized""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "uuid_array": [ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479"), + ], + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["uuid_array"] == [ + "550e8400-e29b-41d4-a716-446655440000", + "f47ac10b-58cc-4372-a567-0e02b2c3d479", + ] + assert all(isinstance(uuid, str) for uuid in serialized["uuid_array"]) + + +@pytest.mark.asyncio +async def test_serialize_mixed_special_types(): + """Test serialization of multiple special types together""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "ip_addr": IPv4Address("192.168.1.1"), + "ip_network": IPv4Network("10.0.0.0/8"), + "uuid_col": UUID("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), + "mixed_array": [ + IPv4Address("10.0.0.1"), + UUID("550e8400-e29b-41d4-a716-446655440000"), + ], + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["ip_addr"] == "192.168.1.1" + assert serialized["ip_network"] == "10.0.0.0/8" + assert serialized["uuid_col"] == "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11" + assert serialized["mixed_array"] == [ + "10.0.0.1", + "550e8400-e29b-41d4-a716-446655440000", + ] + + +@pytest.mark.asyncio +async def test_serialize_geometric_point(): + """Test serialization of PostgreSQL POINT type""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "location": Point(1.5, 2.5), + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["location"] == "(1.5, 2.5)" + + +@pytest.mark.asyncio +async def test_serialize_geometric_line(): + """Test serialization of PostgreSQL LINE type""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "boundary": Line(1, -1, 0), # x - y = 0 + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["boundary"] == "(1, -1, 0)" + + +@pytest.mark.asyncio +async def test_serialize_geometric_lseg(): + """Test serialization of PostgreSQL LSEG (line segment) type""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "segment": LineSegment((0, 0), (1, 1)), + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["segment"] == "[(0.0, 0.0), (1.0, 1.0)]" + + +@pytest.mark.asyncio +async def test_serialize_geometric_box(): + """Test serialization of PostgreSQL BOX type""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "area": Box((2, 2), (0, 0)), + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["area"] == "[(2.0, 2.0), (0.0, 0.0)]" + + +@pytest.mark.asyncio +async def test_serialize_geometric_path(): + """Test serialization of PostgreSQL PATH type""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "route": Path((0, 0), (1, 1), (2, 0)), + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["route"] == "open [(0.0, 0.0), (1.0, 1.0), (2.0, 0.0)]" + + +@pytest.mark.asyncio +async def test_serialize_geometric_polygon(): + """Test serialization of PostgreSQL POLYGON type""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "shape": Polygon((0, 0), (1, 0), (1, 1), (0, 1)), + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["shape"] == "[(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]" + + +@pytest.mark.asyncio +async def test_serialize_geometric_circle(): + """Test serialization of PostgreSQL CIRCLE type""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "zone": Circle((0, 0), 5), + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["zone"] == "((0, 0), 5)" + + +@pytest.mark.asyncio +async def test_serialize_bitstring(): + """Test serialization of PostgreSQL BIT/VARBIT types""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + "flags": BitString("10101010"), + } + + # Execute + serialized = source.serialize(doc) + + # Assert + assert serialized["flags"] == "1010 1010" + + +@pytest.mark.asyncio +async def test_serialize_all_special_types(): + """Test serialization of all PostgreSQL special types together""" + # Setup + async with create_postgresql_source() as source: + doc = { + "id": 1, + # Network types + "ip_inet": IPv4Address("192.168.1.1"), + "ip_cidr": IPv4Network("10.0.0.0/8"), + # UUID + "uuid_col": UUID("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), + # Geometric types + "point_col": Point(1.5, 2.5), + "line_col": Line(1, -1, 0), + "lseg_col": LineSegment((0, 0), (1, 1)), + "box_col": Box((2, 2), (0, 0)), + "path_col": Path((0, 0), (1, 1), (2, 0)), + "polygon_col": Polygon((0, 0), (1, 0), (1, 1), (0, 1)), + "circle_col": Circle((0, 0), 5), + # BitString + "bit_col": BitString("10101010"), + } + + # Execute + serialized = source.serialize(doc) + + # Assert - check exact serialized values + assert serialized["ip_inet"] == "192.168.1.1" + assert serialized["ip_cidr"] == "10.0.0.0/8" + assert serialized["uuid_col"] == "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11" + assert serialized["point_col"] == "(1.5, 2.5)" + assert serialized["line_col"] == "(1, -1, 0)" + assert serialized["lseg_col"] == "[(0.0, 0.0), (1.0, 1.0)]" + assert serialized["box_col"] == "[(2.0, 2.0), (0.0, 0.0)]" + assert serialized["path_col"] == "open [(0.0, 0.0), (1.0, 1.0), (2.0, 0.0)]" + assert ( + serialized["polygon_col"] + == "[(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]" + ) + assert serialized["circle_col"] == "((0, 0), 5)" + assert serialized["bit_col"] == "1010 1010"