Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions connectors/sources/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,29 @@

import ssl
from functools import cached_property, partial
from ipaddress import (
IPv4Address,
IPv4Interface,
IPv4Network,
IPv6Address,
IPv6Interface,
IPv6Network,
)
from urllib.parse import quote
from uuid import UUID

import fastjsonschema
from asyncpg.exceptions._base import InternalClientError
from asyncpg.types import (
BitString,
Box,
Circle,
Line,
LineSegment,
Path,
Point,
Polygon,
)
from fastjsonschema import JsonSchemaValueException
from sqlalchemy import text
from sqlalchemy.exc import ProgrammingError
Expand Down Expand Up @@ -510,6 +529,76 @@ async def ping(self):
msg = f"Can't connect to Postgresql on {self.postgresql_client.host}."
raise Exception(msg) from e

def serialize(self, doc):
    """Serialize asyncpg-specific values in *doc* before indexing.

    asyncpg returns special Python objects for several PostgreSQL column
    types; this override converts them to plain strings so the base
    serializer can handle the document:
    - Network columns (INET, CIDR) -> ``ipaddress`` objects
    - UUID columns -> ``uuid.UUID`` objects
    - Geometric columns (POINT, LSEG, BOX, PATH, POLYGON, LINE, CIRCLE)
      -> ``asyncpg.types`` objects
    - Bit-string columns (BIT, VARBIT) -> ``asyncpg.types.BitString``

    Args:
        doc (Dict): Dictionary to be serialized

    Returns:
        doc (Dict): Serialized version of dictionary
    """

    def _convert(item):
        """Recursively convert a single value based on its runtime type.

        Args:
            item (Any): Value to be serialized

        Returns:
            item (Any): Serialized version of input value.
        """
        if isinstance(
            item,
            (
                IPv4Address,
                IPv6Address,
                IPv4Interface,
                IPv6Interface,
                IPv4Network,
                IPv6Network,
            ),
        ):
            return str(item)
        if isinstance(item, UUID):
            return str(item)
        # NOTE: asyncpg geometric types subclass tuple, so these checks must
        # run before the generic list/tuple branch below.
        if isinstance(item, Point):
            return f"({item.x}, {item.y})"
        if isinstance(item, LineSegment):
            return f"[({item.p1.x}, {item.p1.y}), ({item.p2.x}, {item.p2.y})]"
        if isinstance(item, Box):
            return f"[({item.high.x}, {item.high.y}), ({item.low.x}, {item.low.y})]"
        if isinstance(item, Polygon):
            # Polygon inherits from Path, so check it first
            return str([(p.x, p.y) for p in item.points])
        if isinstance(item, Path):
            vertices = [(p.x, p.y) for p in item.points]
            label = "closed" if item.is_closed else "open"
            return f"{label} {str(vertices)}"
        if isinstance(item, (Line, Circle)):
            return str(item)
        if isinstance(item, BitString):
            return item.as_string()
        if isinstance(item, (list, tuple)):
            return [_convert(element) for element in item]
        if isinstance(item, dict):
            return {key: _convert(nested) for key, nested in item.items()}
        return item

    for field_name, field_value in doc.items():
        doc[field_name] = _convert(field_value)

    return super().serialize(doc)

def row2doc(self, row, doc_id, table, timestamp):
row.update(
{
Expand Down
114 changes: 113 additions & 1 deletion tests/sources/fixtures/postgresql/fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@
import random

import asyncpg
from asyncpg.types import (
BitString,
Box,
Circle,
Line,
LineSegment,
Path,
Point,
Polygon,
)

from tests.commons import WeightedFakeProvider

Expand Down Expand Up @@ -36,9 +46,12 @@

event_loop = asyncio.get_event_loop()

# Number of test rows in special_types table for serialization testing
SPECIAL_TYPES_TEST_ROWS = 3


def get_num_docs():
    """Print the number of documents the sync is expected to produce.

    Count = surviving rows per regular table times the number of tables,
    plus the fixed number of rows in the ``special_types`` table used for
    serialization testing. The stale pre-change print (which omitted the
    special-types rows) is removed so only the correct count is emitted.
    """
    print(NUM_TABLES * (RECORD_COUNT - RECORDS_TO_DELETE) + SPECIAL_TYPES_TEST_ROWS)


async def load():
Expand Down Expand Up @@ -81,6 +94,104 @@ async def inject_lines(table, connect, lines):
inserted += batch_size
print(f"Inserting batch #{batch} of {batch_size} documents.")

async def create_special_types_table():
    """Create a table with PostgreSQL special types that require serialization.

    Creates the ``special_types`` table (network, UUID, geometric, and
    bit-string columns, plus array variants) and inserts
    ``SPECIAL_TYPES_TEST_ROWS`` fixed rows of test data.

    Fix: the connection is now released in a ``finally`` block so that a
    failure in ``execute``/``executemany`` no longer leaks the connection.
    """
    connect = await asyncpg.connect(CONNECTION_STRING)
    try:
        print("Creating special_types table for serialization testing...")
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS special_types (
            id SERIAL PRIMARY KEY,
            ip_inet INET,
            ip_cidr CIDR,
            uuid_col UUID,
            point_col POINT,
            line_col LINE,
            lseg_col LSEG,
            box_col BOX,
            path_col PATH,
            polygon_col POLYGON,
            circle_col CIRCLE,
            bit_col BIT(8),
            varbit_col VARBIT(16),
            inet_array INET[],
            uuid_array UUID[]
        )
        """
        await connect.execute(create_table_sql)

        print("Inserting special type test data...")
        insert_sql = """
        INSERT INTO special_types (
            ip_inet, ip_cidr, uuid_col,
            point_col, line_col, lseg_col, box_col, path_col, polygon_col, circle_col,
            bit_col, varbit_col,
            inet_array, uuid_array
        ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
        """

        test_rows = [
            (
                "192.168.1.1",
                "10.0.0.0/8",
                "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11",
                Point(1.5, 2.5),
                Line(1, -1, 0),
                LineSegment((0, 0), (1, 1)),
                Box((2, 2), (0, 0)),
                Path((0, 0), (1, 1), (2, 0)),
                Polygon((0, 0), (1, 0), (1, 1), (0, 1)),
                Circle((0, 0), 5),
                BitString("10101010"),
                BitString("1101"),
                ["192.168.1.1", "10.0.0.1"],
                [
                    "550e8400-e29b-41d4-a716-446655440000",
                    "f47ac10b-58cc-4372-a567-0e02b2c3d479",
                ],
            ),
            (
                "2001:db8::1",
                "2001:db8::/32",
                "123e4567-e89b-12d3-a456-426614174000",
                Point(-3.14, 2.71),
                Line(2, 3, -6),
                LineSegment((-1, -1), (1, 1)),
                Box((10, 10), (-10, -10)),
                Path((0, 0), (3, 0), (3, 3), (0, 3), is_closed=True),
                Polygon((-1, -1), (1, -1), (1, 1), (-1, 1)),
                Circle((5, 5), 2.5),
                BitString("11110000"),
                BitString("101010101010"),
                ["::1", "fe80::1"],
                ["aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"],
            ),
            (
                "10.30.0.9/24",
                "172.16.0.0/12",
                "00000000-0000-0000-0000-000000000000",
                Point(0, 0),
                Line(0, 1, 0),
                LineSegment((5, 5), (10, 10)),
                Box((100, 100), (50, 50)),
                Path((0, 0), (5, 0), (5, 5)),
                Polygon((0, 0), (4, 0), (4, 3), (0, 3)),
                Circle((0, 0), 10),
                BitString("00000000"),
                BitString("1111111111111111"),
                ["192.0.2.1", "198.51.100.1", "203.0.113.1"],
                [
                    "12345678-1234-5678-1234-567812345678",
                    "87654321-4321-8765-4321-876543218765",
                ],
            ),
        ]

        await connect.executemany(insert_sql, test_rows)
        print(f"Inserted {len(test_rows)} rows with special types")
    finally:
        # Always release the connection, even when table creation or the
        # insert raises — previously an exception leaked the connection.
        await connect.close()

async def load_rows():
"""N tables of 10001 rows each. each row is ~ 1024*20 bytes"""
connect = await asyncpg.connect(CONNECTION_STRING)
Expand All @@ -93,6 +204,7 @@ async def load_rows():

await create_readonly_user()
await load_rows()
await create_special_types_table()


async def remove():
Expand Down
Loading