Skip to content

Commit fb00584

Browse files
committed
test: Fix tests
1 parent d30be5c commit fb00584

File tree

3 files changed

+44
-28
lines changed

3 files changed

+44
-28
lines changed

cloudquery/sdk/internal/memdb/memdb.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import List, Generator, Dict, Any
88
import pyarrow as pa
99
from cloudquery.sdk.schema.table import Table
10+
from cloudquery.sdk.schema.arrow import METADATA_TABLE_NAME
1011
from cloudquery.sdk.types import JSONType
1112
from dataclasses import dataclass, field
1213

@@ -155,7 +156,7 @@ def __init__(self) -> None:
155156
table1.name: table1,
156157
table2.name: table2,
157158
}
158-
self._db: Dict[str, pa.RecordBatch] = {}
159+
self._db: List[pa.RecordBatch] = []
159160
self._client = Client()
160161

161162
def set_logger(self, logger) -> None:
@@ -202,18 +203,15 @@ def sync(
202203
def write(self, writer: Generator[message.WriteMessage, None, None]) -> None:
203204
for msg in writer:
204205
if isinstance(msg, message.WriteMigrateTableMessage):
205-
if msg.table.name not in self._db:
206-
self._db[msg.table.name] = msg.table
207-
self._tables[msg.table.name] = msg.table
206+
pass
208207
elif isinstance(msg, message.WriteInsertMessage):
209-
table = schema.Table.from_arrow_schema(msg.record.schema)
210-
self._db[table.name] = msg.record
208+
self._db.append(msg.record)
211209
else:
212210
raise NotImplementedError(f"Unknown message type {type(msg)}")
213211

214212
def read(self, table: Table) -> Generator[message.ReadMessage, None, None]:
215-
for table, record in self._db.items():
216-
recordMetadata = record.schema.metadata.get(schema.MetadataTableName)
213+
for record in self._db:
214+
recordMetadata = record.schema.metadata.get(METADATA_TABLE_NAME).decode("utf-8")
217215
if recordMetadata == table.name:
218216
yield message.ReadMessage(record)
219217

tests/internal/memdb/memdb.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,25 @@
11
from cloudquery.sdk.internal import memdb
22
from cloudquery.sdk.internal.servers.plugin_v3 import plugin
33
from cloudquery.sdk.plugin import SyncOptions
4+
from cloudquery.sdk.message import SyncMigrateTableMessage, SyncInsertMessage
5+
import structlog
46

57

68
def test_memdb():
79
p = memdb.MemDB()
10+
p.set_logger(structlog.get_logger())
811
p.init(plugin.sanitize_spec(b"null"))
912
msgs = []
10-
for msg in p.sync(SyncOptions(tables=["*"])):
13+
for msg in p.sync(SyncOptions(tables=["*"],skip_tables=[])):
1114
msgs.append(msg)
12-
assert len(msgs) == 0
15+
assert len(msgs) == 18
16+
17+
assert isinstance(msgs[0], SyncMigrateTableMessage)
18+
assert isinstance(msgs[1], SyncMigrateTableMessage)
19+
assert isinstance(msgs[2], SyncMigrateTableMessage)
20+
21+
# other messages should be inserts
22+
for msg in msgs[3:]:
23+
assert isinstance(msg, SyncInsertMessage)
24+
25+

tests/serve/plugin.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,9 @@
1414
from cloudquery.sdk.types.uuid import UUIDType
1515

1616
test_table = Table(
17-
"test",
18-
[
19-
Column("id", pa.int64()),
20-
Column("name", pa.string()),
21-
Column("json", JSONType()),
22-
Column("uuid", UUIDType()),
23-
],
24-
)
25-
17+
"test_table",
18+
[Column("id", pa.int64()), Column("name", pa.string()), Column("json", JSONType()), Column("uuid", UUIDType())],
19+
)
2620

2721
def test_plugin_serve():
2822
p = MemDB()
@@ -72,17 +66,29 @@ def writer_iterator():
7266

7367
stub.Write(writer_iterator())
7468

75-
response = stub.GetTables(plugin_pb2.GetTables.Request(tables=["*"]))
69+
response = stub.GetTables(plugin_pb2.GetTables.Request(tables=["*"],skip_tables=[]))
7670
schemas = arrow.new_schemas_from_bytes(response.tables)
77-
assert len(schemas) == 4
71+
assert len(schemas) == 3
7872

79-
response = stub.Sync(plugin_pb2.Sync.Request(tables=["*"]))
73+
response = stub.Sync(plugin_pb2.Sync.Request(tables=["*"],skip_tables=[]))
74+
total_migrate_tables = 0
8075
total_records = 0
76+
total_errors = 0
8177
for msg in response:
82-
if msg.insert is not None:
78+
message_type = msg.WhichOneof("message")
79+
if message_type == "insert":
8380
rec = arrow.new_record_from_bytes(msg.insert.record)
81+
assert rec.num_rows > 0
8482
total_records += 1
85-
assert total_records == 1
83+
elif message_type == "migrate_table":
84+
total_migrate_tables += 1
85+
elif message_type == "error":
86+
total_errors += 1
87+
else:
88+
raise NotImplementedError(f"Unknown message type {type(msg)}")
89+
assert total_migrate_tables == 3
90+
assert total_records == 15
91+
assert total_errors == 0
8692
finally:
8793
cmd.stop()
8894
pool.shutdown()
@@ -122,8 +128,7 @@ def test_plugin_read():
122128
],
123129
schema=test_table.to_arrow_schema(),
124130
)
125-
p._db["test_1"] = sample_record_1
126-
p._db["test_2"] = sample_record_2
131+
p._db = [sample_record_1, sample_record_2]
127132

128133
cmd = serve.PluginCommand(p)
129134
port = random.randint(5000, 50000)
@@ -191,7 +196,7 @@ def test_plugin_package():
191196
},
192197
{
193198
"name": "id",
194-
"type": "string",
199+
"type": "int64",
195200
"description": "",
196201
"incremental_key": True,
197202
"primary_key": True,
@@ -247,7 +252,7 @@ def test_plugin_package():
247252
},
248253
{
249254
"name": "id",
250-
"type": "string",
255+
"type": "int64",
251256
"description": "",
252257
"incremental_key": False,
253258
"primary_key": False,

0 commit comments

Comments
 (0)