Commit 6b09d3b

some more work on scheduler with relations
1 parent ac5089c commit 6b09d3b

5 files changed: +89 -30 lines changed

cloudquery/sdk/internal/servers/plugin_v3/plugin.py

Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,5 @@
 import pyarrow as pa
+import structlog

 from cloudquery.plugin_v3 import plugin_pb2, plugin_pb2_grpc
 from cloudquery.sdk.message import SyncInsertMessage, SyncMigrateTableMessage
@@ -7,7 +8,8 @@


 class PluginServicer(plugin_pb2_grpc.PluginServicer):
-    def __init__(self, plugin: Plugin):
+    def __init__(self, plugin: Plugin, logger=None):
+        self._logger = logger if logger is not None else structlog.get_logger()
         self._plugin = plugin

     def GetName(self, request, context):
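
A minimal sketch of how the new optional logger argument might be supplied when constructing the servicer; the bind() call and the my_plugin instance are illustrative assumptions, not part of this commit, and omitting the argument falls back to structlog.get_logger():

import structlog

from cloudquery.sdk.internal.servers.plugin_v3.plugin import PluginServicer

# my_plugin stands in for whatever Plugin instance a real plugin provides.
my_plugin = ...

# Bind static context once so every log line emitted by the servicer carries it.
logger = structlog.get_logger().bind(component="plugin_servicer")
servicer = PluginServicer(my_plugin, logger)

# Passing no logger keeps the default added in this commit (structlog.get_logger()).
default_servicer = PluginServicer(my_plugin)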

cloudquery/sdk/scheduler/scheduler.py

Lines changed: 45 additions & 22 deletions
@@ -2,6 +2,7 @@
 from typing import List, Generator, Any
 import queue
 import time
+import structlog
 from enum import Enum
 from cloudquery.sdk.schema import Table, Resource
 from cloudquery.sdk.message import SyncMessage, SyncInsertMessage, SyncMigrateTableMessage
@@ -12,30 +13,33 @@

 QUEUE_PER_WORKER = 100

+
 class ThreadPoolExecutorWithQueueSizeLimit(futures.ThreadPoolExecutor):
     def __init__(self, maxsize, *args, **kwargs):
         super(ThreadPoolExecutorWithQueueSizeLimit, self).__init__(*args, **kwargs)
         self._work_queue = queue.Queue(maxsize=maxsize)


-class WorkerStatus:
-    def __init__(self, total_table_resolvers) -> None:
-        self._total_table_resolvers = total_table_resolvers
+class TableResolverStarted:
+    def __init__(self, count=1) -> None:
+        self._count = count

     @property
-    def total_table_resolvers(self):
-        return self._total_table_resolvers
+    def count(self):
+        return self._count


-class TableResolverStatus:
+class TableResolverFinished:
     def __init__(self) -> None:
         pass


 class Scheduler:
-    def __init__(self, concurrency: int, queue_size: int = 0, max_depth : int = 3):
+    def __init__(self, concurrency: int, queue_size: int = 0, max_depth : int = 3, logger=None):
         self._queue = queue.Queue()
         self._max_depth = max_depth
+        if logger is None:
+            self._logger = structlog.get_logger()
         if concurrency <= 0:
             raise ValueError("concurrency must be greater than 0")
         if max_depth <= 0:
@@ -49,34 +53,53 @@ def __init__(self, concurrency: int, queue_size: int = 0, max_depth : int = 3):
             current_depth_concurrency = current_depth_concurrency // 2 if current_depth_concurrency > 1 else 1
             current_depth_queue_size = current_depth_queue_size // 2 if current_depth_queue_size > 1 else 1

+    def shutdown(self):
+        for pool in self._pools:
+            pool.shutdown()
+
     def resolve_resource(self, resolver: TableResolver, client, parent: Resource, item: Any) -> Resource:
-        resource = Resource(resolver.table, None, item)
+        resource = Resource(resolver.table, parent, item)
         resolver.pre_resource_resolve(client, resource)
         for column in resolver.table.columns:
             resolver.resolve_column(client, resource, column.name)
         resolver.post_resource_resolve(client, resource)
         return resource

-    def resolve_table(self, resolver: TableResolver, client, parent_item: Any, res: queue.Queue):
+    def resolve_table(self, resolver: TableResolver, depth: int, client, parent_item: Resource, res: queue.Queue):
+        table_resolvers_started = 0
         try:
+            if depth == 0:
+                self._logger.info("table resolver started", table=resolver.table.name, depth=depth)
+            else:
+                self._logger.debug("table resolver started", table=resolver.table.name, depth=depth)
+            total_resources = 0
             for item in resolver.resolve(client, parent_item):
                 resource = self.resolve_resource(resolver, client, parent_item, item)
                 res.put(SyncInsertMessage(resource.to_arrow_record()))
+                for child_resolvers in resolver.child_resolvers:
+                    self._pools[depth + 1].submit(self.resolve_table, child_resolvers, depth + 1, client, resource, res)
+                    table_resolvers_started += 1
+                total_resources += 1
+            if depth == 0:
+                self._logger.info("table resolver finished successfully", table=resolver.table.name, depth=depth)
+            else:
+                self._logger.debug("table resolver finished successfully", table=resolver.table.name, depth=depth)
         except Exception as e:
-            traceback.print_exc()
-            print("exception")
-            print(e)
+            self._logger.error("table resolver finished with error", table=resolver.table.name, depth=depth, exception=e)
         finally:
-            res.put(TableResolverStatus())
+            res.put(TableResolverStarted(count=table_resolvers_started))
+            res.put(TableResolverFinished())

     def _sync(self, client, resolvers: List[TableResolver], res: queue.Queue, deterministic_cq_id=False):
         total_table_resolvers = 0
-        for resolver in resolvers:
-            clients = resolver.multiplex(client)
-            for client in clients:
-                self._pools[0].submit(self.resolve_table, resolver, client, None, res)
-                total_table_resolvers += 1
-        res.put(WorkerStatus(total_table_resolvers))
+        try:
+            for resolver in resolvers:
+                clients = resolver.multiplex(client)
+                for client in clients:
+                    self._pools[0].submit(self.resolve_table, resolver, 0, client, None, res)
+                    total_table_resolvers += 1
+        finally:
+            res.put(TableResolverStarted(total_table_resolvers))

     def sync(self, client, resolvers: List[TableResolver], deterministic_cq_id=False) -> Generator[SyncMessage, None, None]:
         res = queue.Queue()
@@ -88,12 +111,12 @@ def sync(self, client, resolvers: List[TableResolver], deterministic_cq_id=False
         finished_table_resovlers = 0
         while True:
             message = res.get()
-            if type(message) == WorkerStatus:
-                total_table_resolvers += message.total_table_resolvers
+            if type(message) == TableResolverStarted:
+                total_table_resolvers += message.count
                 if total_table_resolvers == finished_table_resovlers:
                     break
                 continue
-            elif type(message) == TableResolverStatus:
+            elif type(message) == TableResolverFinished:
                 finished_table_resovlers += 1
                 if total_table_resolvers == finished_table_resovlers:
                     break
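
The sync loop above terminates once every scheduled resolver has reported back. A self-contained sketch of that counting protocol, using stand-in classes rather than the SDK's own TableResolverStarted/TableResolverFinished, illustrates why the loop cannot exit while child resolvers are still being scheduled:

import queue


class Started:
    """Stand-in for TableResolverStarted: N more resolvers were scheduled."""
    def __init__(self, count=1):
        self.count = count


class Finished:
    """Stand-in for TableResolverFinished: one resolver completed."""


def drain(q):
    total, finished = 0, 0
    while True:
        msg = q.get()
        if isinstance(msg, Started):
            total += msg.count
        elif isinstance(msg, Finished):
            finished += 1
        else:
            yield msg  # a data message, e.g. an insert
            continue
        if total == finished:
            break  # every resolver that was ever scheduled has finished


# One top-level resolver that emits a row and schedules one child resolver.
q = queue.Queue()
q.put(Started(1))      # _sync scheduled the top-level resolver
q.put("parent row")    # the parent emitted a record
q.put(Started(1))      # the parent scheduled a child before finishing...
q.put(Finished())      # ...so the total stays ahead of the finished count
q.put("child row")
q.put(Started(0))      # the child scheduled nothing further
q.put(Finished())      # now total == finished and drain() stops
print(list(drain(q)))  # ['parent row', 'child row']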

cloudquery/sdk/scheduler/table_resolver.py

Lines changed: 11 additions & 2 deletions
@@ -4,13 +4,18 @@
 from typing import Any,Generator

 class TableResolver:
-    def __init__(self, table: Table) -> None:
+    def __init__(self, table: Table, child_resolvers=[]) -> None:
         self._table = table
+        self._child_resolvers = child_resolvers

     @property
     def table(self) -> Table:
         return self._table

+    @property
+    def child_resolvers(self):
+        return self._child_resolvers
+
     def multiplex(self, client):
         return [client]

@@ -21,7 +26,11 @@ def pre_resource_resolve(self, client, resource):
         return

     def resolve_column(self, client, resource: Resource, column_name: str):
-        if hasattr(resource.item, column_name):
+        if type(resource.item) is dict:
+            if column_name in resource.item:
+                resource.set(column_name, resource.item[column_name])
+        else:
+            if hasattr(resource.item, column_name):
                 resource.set(column_name, resource.item[column_name])

     def post_resource_resolve(self, client, resource):
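
A hedged sketch of how a resolver pair might use the new child_resolvers hook; the client methods (customers(), orders_for()) and the parent_resource.item["id"] access are illustrative assumptions based on this diff, not a documented contract:

from typing import Any, Generator

import pyarrow as pa

from cloudquery.sdk.scheduler import TableResolver
from cloudquery.sdk.schema import Column, Table


class OrdersResolver(TableResolver):
    """Child resolver: derives its rows from the parent resource it receives."""

    def __init__(self) -> None:
        super().__init__(Table("orders", [Column("order_id", pa.int64())]))

    def resolve(self, client, parent_resource) -> Generator[Any, None, None]:
        # The scheduler passes the parent's Resource; its raw item is a dict here,
        # so resolve_column can pick values straight out of the yielded dicts.
        for order_id in client.orders_for(parent_resource.item["id"]):
            yield {"order_id": order_id}


class CustomersResolver(TableResolver):
    """Top-level resolver: parent_resource is None at depth 0."""

    def __init__(self) -> None:
        super().__init__(
            Table("customers", [Column("id", pa.int64())]),
            child_resolvers=[OrdersResolver()],
        )

    def resolve(self, client, parent_resource) -> Generator[Any, None, None]:
        for customer in client.customers():  # e.g. {"id": 1}
            yield customer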

cloudquery/sdk/serve/plugin.py

Lines changed: 7 additions & 1 deletion
@@ -1,4 +1,5 @@
 import argparse
+import structlog
 from concurrent import futures

 import grpc
@@ -14,6 +15,10 @@
 DOC_FORMATS = ["json", "markdown"]


+def get_logger(args):
+    log = structlog.get_logger()
+    return log
+
 class PluginCommand:
     def __init__(self, plugin: Plugin):
         self._plugin = plugin
@@ -57,11 +62,12 @@ def run(self, args):
             sys.exit(1)

     def _serve(self, args):
+        logger = get_logger(args)
         server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
         discovery_pb2_grpc.add_DiscoveryServicer_to_server(
             DiscoveryServicer([3]), server)
         plugin_pb2_grpc.add_PluginServicer_to_server(
-            PluginServicer(self._plugin), server)
+            PluginServicer(self._plugin, logger), server)
         server.add_insecure_port(args.address)
         print("Starting server. Listening on " + args.address)
         server.start()
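
get_logger currently ignores args and returns structlog's default logger. Purely as an illustration of where CLI-driven configuration could hook in later (the log_format attribute is hypothetical and not part of this commit), it might grow into something like:

import structlog


def get_logger(args):
    # Hypothetical: args.log_format does not exist in this commit; it only
    # illustrates how parsed CLI args could drive structlog configuration.
    if getattr(args, "log_format", None) == "json":
        structlog.configure(
            processors=[
                structlog.processors.TimeStamper(fmt="iso"),
                structlog.processors.JSONRenderer(),
            ]
        )
    return structlog.get_logger()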

tests/scheduler/scheduler.py

Lines changed: 23 additions & 4 deletions
@@ -1,6 +1,7 @@

 from typing import Any, List, Generator
 import pyarrow as pa
+import pytest
 from cloudquery.sdk.scheduler import Scheduler, TableResolver
 from cloudquery.sdk.schema import Table, Column, Resource
 from cloudquery.sdk.message import SyncMessage
@@ -12,22 +13,40 @@ def __init__(self):
             Column("test_column", pa.int64())
         ])

+class SchedulerTestChildTable(Table):
+    def __init__(self):
+        super().__init__("test_child_table", [
+            Column("test_child_column", pa.int64())
+        ])
+
 class SchedulerTestTableResolver(TableResolver):
     def __init__(self) -> None:
-        super().__init__(SchedulerTestTable())
+        super().__init__(SchedulerTestTable(), child_resolvers=[SchedulerTestChildTableResolver()])

     def resolve(self, client, parent_resource) -> Generator[Any, None, None]:
         yield {"test_column": 1}

+class SchedulerTestChildTableResolver(TableResolver):
+    def __init__(self) -> None:
+        super().__init__(SchedulerTestChildTable())
+
+    def resolve(self, client, parent_resource) -> Generator[Any, None, None]:
+        yield {"test_child_column": 2}
+
 class TestClient:
     pass

 def test_scheduler():
     client = TestClient()
     s = Scheduler(10)
+    table1 = Table("test_table", [Column("test_column", pa.int64())])
+    expected_record1 = pa.record_batch([[1]], schema=table1.to_arrow_schema())
+    table2 = Table("test_child_table", [Column("test_child_column", pa.int64())])
+    expected_record2 = pa.record_batch([[2]], schema=table2.to_arrow_schema())
     resources: List[SyncMessage] = []
     for resource in s.sync(client, [SchedulerTestTableResolver()]):
         resources.append(resource)
-    assert len(resources) == 2
-    print(resources[1].record.to_pydict())
-
+    assert len(resources) == 3
+    assert resources[1].record == expected_record1
+    assert resources[2].record == expected_record2
+    s.shutdown()