235 changes: 168 additions & 67 deletions cloudquery/sdk/internal/memdb/memdb.py
@@ -3,116 +3,217 @@
from cloudquery.sdk import plugin
from cloudquery.sdk import message
from cloudquery.sdk import schema
from typing import List, Generator, Dict
from cloudquery.sdk.scheduler import Scheduler, TableResolver
from typing import List, Generator, Dict, Any
import pyarrow as pa
from cloudquery.sdk.schema.table import Table
from cloudquery.sdk.schema.arrow import METADATA_TABLE_NAME
from cloudquery.sdk.types import JSONType
from dataclasses import dataclass, field

NAME = "memdb"
VERSION = "development"


class Client:
def __init__(self) -> None:
pass

def id(self):
return "memdb"


class MemDBResolver(TableResolver):
def __init__(
self, table: Table, records: List, child_resolvers: list[TableResolver] = None
) -> None:
super().__init__(table=table, child_resolvers=child_resolvers)
self._records = records

def resolve(self, client: None, parent_resource) -> Generator[Any, None, None]:
for record in self._records:
yield record


class Table1Relation1(Table):
def __init__(self) -> None:
super().__init__(
name="table_1_relation_1",
columns=[
schema.Column(
name="name",
type=pa.string(),
primary_key=True,
not_null=True,
unique=True,
),
schema.Column(name="data", type=JSONType()),
],
title="Table 1 Relation 1",
description="Test Table 1 Relation 1",
)

@property
def resolver(self):
return MemDBResolver(
self,
records=[
{"name": "a", "data": {"a": 1}},
{"name": "b", "data": {"b": 2}},
{"name": "c", "data": {"c": 3}},
],
)


class Table1(Table):
def __init__(self) -> None:
super().__init__(
name="table_1",
columns=[
schema.Column(
name="name",
type=pa.string(),
primary_key=True,
not_null=True,
unique=True,
),
schema.Column(
name="id",
type=pa.int64(),
primary_key=True,
not_null=True,
unique=True,
incremental_key=True,
),
],
title="Table 1",
description="Test Table 1",
is_incremental=True,
relations=[Table1Relation1()],
)

@property
def resolver(self):
child_resolvers: list[TableResolver] = []
for rel in self.relations:
child_resolvers.append(rel.resolver)

return MemDBResolver(
self,
records=[
{"name": "a", "id": 1},
{"name": "b", "id": 2},
{"name": "c", "id": 3},
],
child_resolvers=child_resolvers,
)


class Table2(Table):
def __init__(self) -> None:
super().__init__(
name="table_2",
columns=[
schema.Column(
name="name",
type=pa.string(),
primary_key=True,
not_null=True,
unique=True,
),
schema.Column(name="id", type=pa.int64()),
],
title="Table 2",
description="Test Table 2",
)

@property
def resolver(self):
return MemDBResolver(
self,
records=[
{"name": "a", "id": 1},
{"name": "b", "id": 2},
{"name": "c", "id": 3},
],
)


@dataclass
class Spec:
abc: str = field(default="abc")
concurrency: int = field(default=1000)
queue_size: int = field(default=1000)


class MemDB(plugin.Plugin):
def __init__(self) -> None:
super().__init__(
NAME, VERSION, opts=plugin.plugin.Options(team="cloudquery", kind="source")
)
self._db: Dict[str, pa.RecordBatch] = {}
table1 = Table1()
table2 = Table2()
self._tables: Dict[str, schema.Table] = {
"table_1": schema.Table(
name="table_1",
columns=[
schema.Column(
name="name",
type=pa.string(),
primary_key=True,
not_null=True,
unique=True,
),
schema.Column(
name="id",
type=pa.string(),
primary_key=True,
not_null=True,
unique=True,
incremental_key=True,
),
],
title="Table 1",
description="Test Table 1",
is_incremental=True,
relations=[
schema.Table(
name="table_1_relation_1",
columns=[
schema.Column(
name="name",
type=pa.string(),
primary_key=True,
not_null=True,
unique=True,
),
schema.Column(name="data", type=JSONType()),
],
title="Table 1 Relation 1",
description="Test Table 1 Relation 1",
)
],
),
"table_2": schema.Table(
name="table_2",
columns=[
schema.Column(
name="name",
type=pa.string(),
primary_key=True,
not_null=True,
unique=True,
),
schema.Column(name="id", type=pa.string()),
],
title="Table 2",
description="Test Table 2",
),
table1.name: table1,
table2.name: table2,
}
self._db: List[pa.RecordBatch] = []
self._client = Client()

def set_logger(self, logger) -> None:
self._logger = logger

def init(self, spec, no_connection: bool = False):
if no_connection:
return
self._spec_json = json.loads(spec)
self._spec = Spec(**self._spec_json)
self._scheduler = Scheduler(
concurrency=self._spec.concurrency,
queue_size=self._spec.queue_size,
logger=self._logger,
)

def get_tables(self, options: plugin.TableOptions = None) -> List[plugin.Table]:
tables = list(self._tables.values())

# set parent table relationships
for table in tables:
for relation in table.relations:
relation.parent = table

return schema.filter_dfs(tables, options.tables, options.skip_tables)

def sync(
self, options: plugin.SyncOptions
) -> Generator[message.SyncMessage, None, None]:
for table, record in self._db.items():
yield message.SyncInsertMessage(record)
resolvers: list[TableResolver] = []
for table in self.get_tables(
plugin.TableOptions(
tables=options.tables,
skip_tables=options.skip_tables,
skip_dependent_tables=options.skip_dependent_tables,
)
):
resolvers.append(table.resolver)

return self._scheduler.sync(
self._client, resolvers, options.deterministic_cq_id
)

def write(self, writer: Generator[message.WriteMessage, None, None]) -> None:
for msg in writer:
if isinstance(msg, message.WriteMigrateTableMessage):
if msg.table.name not in self._db:
self._db[msg.table.name] = msg.table
self._tables[msg.table.name] = msg.table
pass
elif isinstance(msg, message.WriteInsertMessage):
table = schema.Table.from_arrow_schema(msg.record.schema)
self._db[table.name] = msg.record
self._db.append(msg.record)
else:
raise NotImplementedError(f"Unknown message type {type(msg)}")

def read(self, table: Table) -> Generator[message.ReadMessage, None, None]:
for table, record in self._db.items():
yield message.ReadMessage(record)
        for record in self._db:
            # each stored record batch carries its table name in its Arrow schema metadata
            record_table_name = record.schema.metadata.get(METADATA_TABLE_NAME).decode("utf-8")
            if record_table_name == table.name:
                yield message.ReadMessage(record)

def close(self) -> None:
        self._db = []
7 changes: 7 additions & 0 deletions cloudquery/sdk/internal/servers/plugin_v3/plugin.py
@@ -7,6 +7,7 @@
from cloudquery.sdk.message import (
SyncInsertMessage,
SyncMigrateTableMessage,
SyncErrorMessage,
WriteInsertMessage,
WriteMigrateTableMessage,
WriteMessage,
@@ -77,6 +78,12 @@ def Sync(self, request, context):
yield plugin_pb2.Sync.Response(
migrate_table=plugin_pb2.Sync.MessageMigrateTable(table=buf)
)
elif isinstance(msg, SyncErrorMessage) and request.withErrorMessages:
@erezrokah (Member, Author) commented on Sep 3, 2025: This is the main change (part 1) to add the feature.
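A hedged sketch of how a consumer of the sync stream might handle these new error responses; drain and handle_record are illustrative names, not part of this PR:

    import structlog
    from cloudquery.sdk.message import SyncErrorMessage, SyncInsertMessage
    from cloudquery.sdk.plugin import SyncOptions

    logger = structlog.get_logger()

    def drain(plugin):  # hypothetical consumer of any plugin's sync stream
        for msg in plugin.sync(SyncOptions(tables=["*"], skip_tables=[])):
            if isinstance(msg, SyncErrorMessage):
                # new in this PR: failures are reported per table instead of aborting the sync
                logger.warning("table sync failed", table=msg.table_name, error=msg.error)
            elif isinstance(msg, SyncInsertMessage):
                handle_record(msg.record)  # hypothetical downstream handler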

yield plugin_pb2.Sync.Response(
error=plugin_pb2.Sync.MessageError(
table_name=msg.table_name, error=msg.error
)
)
else:
# unknown sync message type
raise NotImplementedError()
7 changes: 6 additions & 1 deletion cloudquery/sdk/message/__init__.py
@@ -1,4 +1,9 @@
from .sync import SyncMessage, SyncInsertMessage, SyncMigrateTableMessage
from .sync import (
SyncMessage,
SyncInsertMessage,
SyncMigrateTableMessage,
SyncErrorMessage,
)
from .write import (
WriteMessage,
WriteInsertMessage,
6 changes: 6 additions & 0 deletions cloudquery/sdk/message/sync.py
@@ -13,3 +13,9 @@ def __init__(self, record: pa.RecordBatch):
class SyncMigrateTableMessage(SyncMessage):
def __init__(self, table: pa.Schema):
self.table = table


class SyncErrorMessage(SyncMessage):
def __init__(self, table_name: str, error: str):
self.table_name = table_name
self.error = error
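Constructing the new message is straightforward; the values below are placeholders, not from the PR:

    msg = SyncErrorMessage(table_name="table_1", error="request timed out")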
2 changes: 2 additions & 0 deletions cloudquery/sdk/scheduler/scheduler.py
@@ -8,6 +8,7 @@
SyncMessage,
SyncInsertMessage,
SyncMigrateTableMessage,
SyncErrorMessage,
)
from cloudquery.sdk.schema import Resource
from cloudquery.sdk.stateclient.stateclient import StateClient
@@ -162,6 +163,7 @@ def resolve_table(
depth=depth,
exc_info=e,
)
res.put(SyncErrorMessage(resolver.table.name, str(e)))
@erezrokah (Member, Author): This is the main change (part 2) to add the feature.
finally:
res.put(TableResolverFinished())

17 changes: 15 additions & 2 deletions tests/internal/memdb/memdb.py
@@ -1,12 +1,25 @@
from cloudquery.sdk.internal import memdb
from cloudquery.sdk.internal.servers.plugin_v3 import plugin
from cloudquery.sdk.plugin import SyncOptions
from cloudquery.sdk.message import SyncMigrateTableMessage, SyncInsertMessage
import structlog


def test_memdb():
p = memdb.MemDB()
p.set_logger(structlog.get_logger())
p.init(plugin.sanitize_spec(b"null"))
msgs = []
for msg in p.sync(SyncOptions(tables=["*"])):
    for msg in p.sync(SyncOptions(tables=["*"], skip_tables=[])):
msgs.append(msg)
assert len(msgs) == 0
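    # 18 = 3 migrate messages + 15 inserts (table_1: 3, table_1_relation_1: 3 per parent x 3 parents = 9, table_2: 3)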
assert len(msgs) == 18

assert isinstance(msgs[0], SyncMigrateTableMessage)
assert isinstance(msgs[1], SyncMigrateTableMessage)
assert isinstance(msgs[2], SyncMigrateTableMessage)

# other messages should be inserts
for msg in msgs[3:]:
assert isinstance(msg, SyncInsertMessage)

