Skip to content

Commit da31660

Browse files
committed
wip scheduler
1 parent 9bc72c7 commit da31660

File tree

18 files changed

+339
-89
lines changed

18 files changed

+339
-89
lines changed

cloudquery/sdk/scalar/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from .scalar import Scalar, ScalarInvalidTypeError, ScalarFactory
2+
from .binary import Binary
3+
from .bool import Bool
4+
from .date32 import Date32
5+
from .float64 import Float64
6+
from .int64 import Int64

cloudquery/sdk/scalar/binary.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from cloudquery.sdk.scalar import Scalar, ScalarInvalidTypeError
2+
3+
class Binary(Scalar):
4+
def __init__(self, valid: bool = False, value: bytes = None):
5+
self._valid = valid
6+
self._value = value
7+
8+
def __eq__(self, scalar: Scalar) -> bool:
9+
if scalar is None:
10+
return False
11+
if type(scalar) == Binary:
12+
return self._value == scalar._value and self._valid == scalar._valid
13+
return False
14+
15+
def set(self, scalar):
16+
if scalar is None:
17+
return
18+
19+
if type(scalar) == bytes:
20+
self._valid = True
21+
self._value = scalar
22+
elif type(scalar) == str:
23+
self._valid = True
24+
self._value = scalar.encode()
25+
else:
26+
raise ScalarInvalidTypeError("Invalid type for Binary scalar")

cloudquery/sdk/scalar/bool.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
2+
from cloudquery.sdk.scalar import Scalar, ScalarInvalidTypeError
3+
from typing import Any
4+
5+
def parse_string_to_bool(input_string):
6+
true_strings = ['true', 't', 'yes', 'y', '1']
7+
false_strings = ['false', 'f', 'no', 'n', '0']
8+
9+
lower_input = input_string.lower()
10+
11+
if lower_input in true_strings:
12+
return True
13+
elif lower_input in false_strings:
14+
return False
15+
else:
16+
raise ScalarInvalidTypeError("Invalid boolean string: {}".format(input_string))
17+
18+
class Bool(Scalar):
19+
def __init__(self, valid: bool = False, value: bool = False) -> None:
20+
self._valid = valid
21+
self._value = value
22+
23+
def __eq__(self, scalar: Scalar) -> bool:
24+
if scalar is None:
25+
return False
26+
if type(scalar) == Bool:
27+
return self._value == scalar._value and self._valid == scalar._valid
28+
return False
29+
30+
def set(self, value: Any):
31+
if value is None:
32+
return
33+
34+
if type(value) == bool:
35+
self._value = value
36+
elif type(value) == str:
37+
self._value = parse_string_to_bool(value)
38+
else:
39+
raise ScalarInvalidTypeError("Invalid type for Bool scalar")
40+
41+
self._valid = True

cloudquery/sdk/scalar/date32.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
2+
from cloudquery.sdk.scalar import Scalar, ScalarInvalidTypeError
3+
from datetime import datetime, time
4+
from typing import Any
5+
6+
class Date32(Scalar):
7+
def __init__(self, valid: bool = False, value: bool = False) -> None:
8+
self._valid = valid
9+
self._value = value
10+
11+
def __eq__(self, scalar: Scalar) -> bool:
12+
if scalar is None:
13+
return False
14+
if type(scalar) == Date32:
15+
return self._value == scalar._value and self._valid == scalar._valid
16+
return False
17+
18+
def set(self, value: Any):
19+
if value is None:
20+
return
21+
22+
if type(value) == datetime:
23+
self._value = value
24+
elif type(value) == str:
25+
self._value = datetime.strptime(value, "%Y-%m-%d")
26+
elif type(value) == time:
27+
self._value = datetime.combine(datetime.today(), value)
28+
else:
29+
raise ScalarInvalidTypeError("Invalid type for Bool scalar")
30+
31+
self._valid = True

cloudquery/sdk/scalar/float64.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from cloudquery.sdk.scalar import Scalar, ScalarInvalidTypeError
2+
3+
class Float64(Scalar):
4+
def __init__(self, valid: bool = False, value: float = None):
5+
self._valid = valid
6+
self._value = value
7+
8+
def __eq__(self, scalar: Scalar) -> bool:
9+
if scalar is None:
10+
return False
11+
if type(scalar) == Float64:
12+
return self._value == scalar._value and self._valid == scalar._valid
13+
return False
14+
15+
def set(self, value):
16+
if value is None:
17+
return
18+
19+
if type(value) == int:
20+
self._value = float(value)
21+
elif type(value) == float:
22+
self._value = value
23+
elif type(value) == str:
24+
try:
25+
self._value = float(value)
26+
except ValueError:
27+
raise ScalarInvalidTypeError("Invalid type for Float64 scalar")
28+
else:
29+
raise ScalarInvalidTypeError("Invalid type for Binary scalar")
30+
self._valid = True

cloudquery/sdk/scalar/int64.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from cloudquery.sdk.scalar import Scalar, ScalarInvalidTypeError
2+
3+
class Int64(Scalar):
4+
def __init__(self, valid: bool = False, value: float = None):
5+
self._valid = valid
6+
self._value = value
7+
8+
def __eq__(self, scalar: Scalar) -> bool:
9+
if scalar is None:
10+
return False
11+
if type(scalar) == Int64:
12+
return self._value == scalar._value and self._valid == scalar._valid
13+
return False
14+
15+
def set(self, value):
16+
if value is None:
17+
return
18+
19+
if type(value) == int:
20+
self._value = value
21+
elif type(value) == float:
22+
self._value = int(value)
23+
elif type(value) == str:
24+
try:
25+
self._value = int(value)
26+
except ValueError:
27+
raise ScalarInvalidTypeError("Invalid type for Int64 scalar")
28+
else:
29+
raise ScalarInvalidTypeError("Invalid type for Int64 scalar")
30+
self._valid = True

cloudquery/sdk/scalar/scalar.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pyarrow as pa
2+
from .int64 import Int64
3+
4+
class ScalarInvalidTypeError(Exception):
5+
pass
6+
7+
class Scalar:
8+
@property
9+
def is_valid(self) -> bool:
10+
return self._valid
11+
12+
13+
class ScalarFactory:
14+
def __init__(self):
15+
self._type_map = {
16+
pa.int64: lambda dt: Int64(),
17+
}
18+
19+
def new_scalar(self, dt):
20+
if dt in self._type_map:
21+
return self._type_map[dt]()
22+
else:
23+
raise ScalarInvalidTypeError("Invalid type for scalar")

cloudquery/sdk/scheduler/scheduler.py

Lines changed: 49 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -2,75 +2,19 @@
22
from typing import List, Generator, Any
33
import queue
44
import time
5-
from cloudquery.sdk.schema import Table
5+
from cloudquery.sdk.schema import Table, Resource
66
from cloudquery.sdk.message import SyncMessage, SyncInsertMessage, SyncMigrateMessage
7-
import concurrent.futures
7+
from concurrent import futures
88
from typing import Generator
9-
10-
# This is all WIP
11-
class Task:
12-
def __init__(self, fetcher, parent_item):
13-
self._fetcher = fetcher
14-
self._parent_item = parent_item
15-
16-
class Item:
17-
def __init__(self, fetcher, parent_item):
18-
self._fetcher = fetcher
19-
self._parent_item = parent_item
20-
21-
class Fetcher:
22-
def __init__(self, relations: List[Any]):
23-
self._relations = relations
24-
25-
def get(self, parent_item) -> Generator[SyncInsertMessage]:
26-
for i in range(10):
27-
yield SyncInsertMessage(None)
28-
29-
def process_item(self, item: Any) -> Generator[SyncInsertMessage]:
30-
pass
31-
32-
def transform_item(self, item: Any) -> Any:
33-
pass
34-
35-
36-
def worker(fetcher: Fetcher, parent_item: Any):
37-
for arr in fetcher.get(parent_item):
38-
for res in arr:
39-
fetcher.get()
40-
41-
# while True:
42-
# item = q.get()
43-
# if item is None:
44-
# break
45-
# do_work(item)
46-
# q.task_done()
47-
48-
def worker_task(q, worker_id):
49-
for i in range(5):
50-
time.sleep(0.1) # Simulate work
51-
task = (worker_id, i)
52-
print(f"Worker {worker_id} created task {task}")
53-
q.put(task)
54-
55-
def main_task(q: queue.Queue):
56-
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
57-
while True:
58-
try:
59-
task = q.get(timeout=1) # Wait up to 1 second for a task.
60-
except queue.Empty:
61-
print("All tasks completed.")
62-
break # Exit while loop if no more tasks.
63-
print(f"Main: Running task {task}")
64-
executor.submit(run_task, task)
65-
q.task_done()
66-
67-
def run_task(task):
68-
69-
print(f"Running task {task}")
70-
9+
from .table_resolver import TableResolver
7110

7211
QUEUE_PER_WORKER = 100
7312

13+
class ThreadPoolExecutorWithQueueSizeLimit(futures.ThreadPoolExecutor):
14+
def __init__(self, maxsize, *args, **kwargs):
15+
super(ThreadPoolExecutorWithQueueSizeLimit, self).__init__(*args, **kwargs)
16+
self._work_queue = queue.Queue(maxsize=maxsize)
17+
7418
class Scheduler:
7519
def __init__(self, concurrency: int, queue_size: int = 0, max_depth : int = 3):
7620
self._queue = queue.Queue()
@@ -82,33 +26,50 @@ def __init__(self, concurrency: int, queue_size: int = 0, max_depth : int = 3):
8226
if max_depth <= 0:
8327
raise ValueError("max_depth must be greater than 0")
8428
self._queue_size = queue_size if queue_size > 0 else concurrency * QUEUE_PER_WORKER
85-
self._pools = []
86-
self._queues = []
29+
self._pools : List[ThreadPoolExecutorWithQueueSizeLimit] = []
8730
current_depth_concurrency = concurrency
8831
current_depth_queue_size = queue_size
89-
for i in range(max_depth + 1):
90-
self._queues.append(queue.Queue(maxsize=current_depth_queue_size))
91-
self._pools.append(concurrent.futures.ThreadPoolExecutor(max_workers=current_depth_concurrency))
32+
for _ in range(max_depth + 1):
33+
self._pools.append(ThreadPoolExecutorWithQueueSizeLimit(maxsize=current_depth_queue_size,max_workers=current_depth_concurrency))
9234
current_depth_concurrency = current_depth_concurrency // 2 if current_depth_concurrency > 1 else 1
9335
current_depth_queue_size = current_depth_queue_size // 2 if current_depth_queue_size > 1 else 1
9436

95-
def worker(self, max_depth: int):
96-
while True:
97-
task = self._queues[max_depth].get()
98-
if task is None:
99-
break
100-
self._pools[max_depth].submit(*task)
101-
102-
def table_resolver(self, table: Table, client, res: queue.Queue):
103-
for resource in table.resolve(client):
104-
pass
105-
# task.resolve
37+
def resolve_resource(self, resolver: TableResolver, client, parent: Resource, item: Any) -> Resource:
38+
resource = Resource(resolver.table, None, item)
39+
resolver.pre_resource_resolve(client, resource)
40+
for column in resolver.table.columns:
41+
resolver.resolve_column(client, resource, column.name)
42+
resolver.post_resource_resolve(client, resource)
43+
return resource
44+
45+
def resolve_table(self, resolver: TableResolver, client, parent_item: Any, res: queue.Queue):
46+
for item in resolver.resolve(client, parent_item):
47+
resource = self.resolve_resource(resolver, client, parent_item)
48+
res.put(SyncInsertMessage(resource))
49+
res.put(None)
10650

107-
def sync(self, client, tables: List[Table], res: queue.Queue, deterministic_cq_id=False):
108-
for table in tables:
109-
res.put(SyncMigrateMessage(record=table.to_arrow_schemas()))
110-
for table in tables:
111-
clients = table.multiplex(client)
112-
for client in clients:
113-
self._queues[0].put((table.resolver, client, res))
114-
self._queues[0].put(None)
51+
def _sync(self, client, resolvers: List[TableResolver], res: queue.Queue, deterministic_cq_id=False):
52+
internal_res = queue.Queue()
53+
for resolver in resolvers:
54+
clients = resolver.multiplex(client)
55+
for client in clients:
56+
self._pools[0].submit(self.resolve_table, resolver, client, None, internal_res)
57+
while True:
58+
message = internal_res.get()
59+
if message is None:
60+
break
61+
res.put(message)
62+
res.put(None)
63+
64+
def sync(self, client, resolvers: List[TableResolver], deterministic_cq_id=False) -> Generator[SyncMessage]:
65+
res = queue.Queue()
66+
for resolver in resolvers:
67+
yield SyncMigrateMessage(record=resolver.table.to_arrow_schemas())
68+
thread = futures.ThreadPoolExecutor()
69+
thread.submit(self._sync, client, resolvers, res, deterministic_cq_id)
70+
while True:
71+
message = res.get()
72+
if message is None:
73+
break
74+
yield message
75+
thread.shutdown()
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
2+
from cloudquery.sdk.schema.table import Table
3+
from cloudquery.sdk.schema import Resource
4+
from typing import Any
5+
6+
class TableResolver:
7+
def __init__(self, table: Table) -> None:
8+
self._table = table
9+
10+
@property
11+
def table(self):
12+
return self._table
13+
14+
def multiplex(self, client):
15+
return [client]
16+
17+
def resolve(self, client, parent_resource) -> Any:
18+
raise NotImplementedError()
19+
20+
def pre_resource_resolve(self, client, resource):
21+
return
22+
23+
def resolve_column(self, client, resource: Resource, column_name: str):
24+
if hasattr(resource.item, column_name):
25+
resource.set(column_name, resource.item[column_name])
26+
27+
def post_resource_resolve(self, client, resource):
28+
return

cloudquery/sdk/schema/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
from .column import Column
22
from .table import Table, tables_to_arrow_schemas
3+
from .resource import Resource
4+
# from .table_resolver import TableReso

0 commit comments

Comments
 (0)