Skip to content

Commit 3193879

Browse files
feat: State Client (#198)
* wio * Add gRPC connection example * Implement working StateClient. * Remove unnecessary modifications. * Remove unnecessary modifications. * Remove TODO. --------- Co-authored-by: Herman Schaaf <[email protected]>
1 parent a2835bb commit 3193879

File tree

5 files changed

+231
-1
lines changed

5 files changed

+231
-1
lines changed

cloudquery/sdk/internal/servers/plugin_v3/plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def Sync(self, request, context):
6363
skip_dependent_tables=request.skip_dependent_tables,
6464
skip_tables=request.skip_tables,
6565
tables=request.tables,
66-
backend_options=None,
66+
backend_options=request.backend,
6767
)
6868

6969
for msg in self._plugin.sync(options):

cloudquery/sdk/scheduler/scheduler.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
SyncMigrateTableMessage,
1111
)
1212
from cloudquery.sdk.schema import Resource
13+
from cloudquery.sdk.stateclient.stateclient import StateClient
1314
from .table_resolver import TableResolver, Client
1415

1516
QUEUE_PER_WORKER = 100
@@ -35,6 +36,7 @@ class Scheduler:
3536
def __init__(
3637
self, concurrency: int, queue_size: int = 0, max_depth: int = 3, logger=None
3738
):
39+
self._post_sync_hook = lambda: None
3840
self._queue = queue.Queue()
3941
self._max_depth = max_depth
4042
if logger is None:
@@ -201,6 +203,8 @@ def sync(
201203
break
202204
continue
203205
yield message
206+
207+
self._post_sync_hook()
204208
thread.shutdown(wait=True)
205209

206210
def _send_migrate_table_messages(
@@ -210,3 +214,18 @@ def _send_migrate_table_messages(
210214
yield SyncMigrateTableMessage(table=resolver.table.to_arrow_schema())
211215
if resolver.child_resolvers:
212216
yield from self._send_migrate_table_messages(resolver.child_resolvers)
217+
218+
def set_post_sync_hook(self, fn):
219+
"""
220+
Use this to set a function that will be called after the sync is finished,
221+
a la `defer fn()` in Go (but for a single function, rather than a stack).
222+
223+
This is necessary because plugins use this pattern on their sync method:
224+
225+
```
226+
return self._scheduler.sync(...)
227+
```
228+
229+
So if a plugin has a `state_client`, there's nowhere to call the flush method.
230+
"""
231+
self._post_sync_hook = fn

cloudquery/sdk/scheduler/table_resolver.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def __init__(self, table: Table, child_resolvers: Optional[List] = None) -> None
1414
child_resolvers = []
1515
self._table = table
1616
self._child_resolvers = child_resolvers
17+
self.state_client = None
1718

1819
@property
1920
def table(self) -> Table:
@@ -23,6 +24,9 @@ def table(self) -> Table:
2324
def child_resolvers(self):
2425
return self._child_resolvers
2526

27+
def set_state_client(self, state_client):
28+
self.state_client = state_client
29+
2630
def multiplex(self, client: Client) -> List[Client]:
2731
return [client]
2832

cloudquery/sdk/stateclient/__init__.py

Whitespace-only changes.
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
from __future__ import annotations
2+
3+
import grpc
4+
from abc import ABC, abstractmethod
5+
from typing import Generator, Optional, Tuple
6+
7+
import pyarrow as pa
8+
from cloudquery.sdk import schema
9+
from cloudquery.sdk.plugin.plugin import BackendOptions
10+
from cloudquery.plugin_v3 import plugin_pb2, plugin_pb2_grpc, arrow
11+
from functools import wraps
12+
13+
keyColumn = "key"
14+
valueColumn = "value"
15+
16+
17+
class StateClientBuilder:
    """
    Factory for state clients.

    `build` returns a `ConnectedStateClient` when backend options with a
    table name are supplied, and a `NoOpStateClient` otherwise.

    Args:
        backend_options (BackendOptions): carries `connection` & `table_name` strings.

    Returns:
        `ConnectedStateClient` or `NoOpStateClient`
    """

    @staticmethod
    def build(*, backend_options: BackendOptions) -> StateClient:
        # A connected client only makes sense when a state table is configured.
        if backend_options and backend_options.table_name:
            return ConnectedStateClient(backend_options=backend_options)
        return NoOpStateClient(backend_options=backend_options)
36+
37+
38+
class StateClient(ABC):
    """
    Abstract class that defines the interface for a state client.

    Keeps an in-memory snapshot of the state table (``self.mem``) plus the
    set of keys dirtied since the last flush (``self.changes``). Subclasses
    implement only the three methods that require a backend connection, so
    this class is a succinct overview of what a state client does.
    """

    def __init__(self, *, backend_options: BackendOptions):
        # key -> value snapshot of the state table
        self.mem = {}
        # keys modified since the last flush (values are always True)
        self.changes = {}
        self.connection = getattr(backend_options, "connection", None)
        self.table = Table(getattr(backend_options, "table_name", None))

        self.migrate_state_table()
        self.read_all_state()

    def get_key(self, key: str) -> Optional[str]:
        """Return the cached value for *key*, or None if absent."""
        return self.mem.get(key)

    def set_key(self, key: str, value: str) -> None:
        """Store *key* -> *value* in memory and mark the key dirty."""
        self.mem[key] = value
        self.changes[key] = True

    def flush(self):
        """Persist all dirty keys to the backend, then reset the dirty set.

        Clearing ``self.changes`` after a successful write prevents later
        flushes from rewriting keys that were already persisted (the original
        never cleared the dirty set).
        """
        if not self.changes:
            return

        self.write_all_state(self._changed_keys())
        self.changes.clear()

    def _changed_keys(self) -> Generator[Tuple[str, str], None, None]:
        # Only the keys of self.changes matter; the values are a constant True.
        for key in self.changes:
            yield (key, self.mem[key])

    @abstractmethod
    def migrate_state_table(self):
        """Create/migrate the backing state table in the backend."""

    @abstractmethod
    def read_all_state(self):
        """Populate ``self.mem`` from the backend."""

    @abstractmethod
    def write_all_state(self, changed_keys: Generator[Tuple[str, str], None, None]):
        """Persist the given (key, value) pairs to the backend."""
83+
84+
85+
class NoOpStateClient(StateClient):
    """
    State client used when there is no backend connection.

    Every operation is a no-op; ``get_key`` always returns None.
    """

    def get_key(self, key: str) -> Optional[str]:
        return None

    def set_key(self, key: str, value: str) -> None:
        return None

    def migrate_state_table(self):
        return None

    def read_all_state(self):
        return None

    def write_all_state(self, changed_keys: Generator[Tuple[str, str], None, None]):
        return None
104+
105+
106+
def connected(func):
    """
    Decorator that opens a gRPC channel to ``self.connection`` and passes a
    ``PluginStub`` bound to that channel as the first argument of the
    decorated method. The channel is closed when the call returns.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        channel = grpc.insecure_channel(self.connection)
        with channel:
            stub = plugin_pb2_grpc.PluginStub(channel)
            return func(self, stub, *args, **kwargs)

    return wrapper
118+
119+
120+
class ConnectedStateClient(StateClient):
    """
    State client that talks to a backend plugin over gRPC to read/write state.
    """

    @connected
    def migrate_state_table(self, backend_plugin: plugin_pb2_grpc.PluginStub):
        # Write is a request-streaming RPC; _migrate_table_request yields the
        # single migrate-table message.
        backend_plugin.Write(self._migrate_table_request())

    @connected
    def read_all_state(self, backend_plugin: plugin_pb2_grpc.PluginStub):
        request = plugin_pb2.Read.Request(table=self.table.bytes())
        response = backend_plugin.Read(request)

        for row in read_response_to_records(response):
            self.mem[row[keyColumn]] = row[valueColumn]

    @connected
    def write_all_state(
        self,
        backend_plugin: plugin_pb2_grpc.PluginStub,
        changed_keys: Generator[Tuple[str, str], None, None],
    ):
        # Stream one insert request per changed key.
        requests = (self._write_request(key, value) for key, value in changed_keys)
        backend_plugin.Write(requests)

    def _write_request(self, key: str, value: str) -> plugin_pb2.Write.Request:
        batch = pa.RecordBatch.from_arrays(
            [pa.array([key]), pa.array([value])],
            schema=self.table.arrow_schema(),
        )
        return plugin_pb2.Write.Request(
            insert=plugin_pb2.Write.MessageInsert(record=arrow.record_to_bytes(batch))
        )

    def _migrate_table_request(self):
        yield plugin_pb2.Write.Request(
            migrate_table=plugin_pb2.Write.MessageMigrateTable(table=self.table.bytes())
        )
162+
163+
164+
def read_response_to_records(response) -> Generator[dict[str, str], None, None]:
    """Decode a streamed Read response into per-row dicts.

    Each message in *response* carries a serialized Arrow record batch in its
    ``record`` field; every batch is expanded into one dict per row.
    """
    for message in response:
        # Distinct names for the stream message vs. the decoded rows: the
        # original reused `record` for both loops, shadowing the outer
        # variable — harmless but confusing.
        record_batch = arrow.new_record_from_bytes(message.record)
        yield from recordbatch_to_list_of_maps(record_batch)
169+
170+
171+
def recordbatch_to_list_of_maps(
172+
record_batch: pa.RecordBatch,
173+
) -> Generator[dict[str, str], None, None]:
174+
table = pa.Table.from_batches([record_batch])
175+
for row in table.to_pandas().to_dict(orient="records"):
176+
yield row
177+
178+
179+
class Table:
    """
    The two-column (key, value) state table.

    Precomputes the Arrow schema and its serialized form at construction so
    the gRPC request builders can reuse them cheaply; both stay None when no
    table name is given.
    """

    def __init__(self, name):
        self.name = name
        self._arrow_schema = None
        self._bytes = None

        if not self.name:
            return
        self._arrow_schema = self._state_table_schema().to_arrow_schema()
        self._bytes = arrow.schema_to_bytes(self._arrow_schema)

    def arrow_schema(self):
        """Arrow schema of the state table (None when unnamed)."""
        return self._arrow_schema

    def bytes(self):
        """Serialized Arrow schema bytes (None when unnamed)."""
        return self._bytes

    def _state_table_schema(self):
        columns = [
            schema.Column(name=keyColumn, type=pa.string(), primary_key=True),
            schema.Column(name=valueColumn, type=pa.string()),
        ]
        return schema.Table(name=self.name, columns=columns)

0 commit comments

Comments
 (0)