|
4 | 4 | from typing import Literal |
5 | 5 |
|
6 | 6 | import polars as pl |
7 | | - |
8 | 7 | from queries.common_utils import ( |
9 | 8 | check_query_result_pl, |
| 9 | + execute_all as common_execute_all, |
10 | 10 | get_table_path, |
11 | 11 | run_query_generic, |
12 | 12 | ) |
| 13 | +from queries.polars.cloud_utils import get_compute_context, stop_compute_context |
13 | 14 | from settings import Settings |
14 | 15 |
|
15 | 16 | settings = Settings() |
16 | 17 |
|
17 | 18 |
|
| 19 | +def execute_all() -> None: |
| 20 | + if not settings.run.polars_cloud: |
| 21 | + return execute_all("polars") |
| 22 | + |
| 23 | + # for polars cloud we have to create the compute context, |
| 24 | + # reuse it across the queries, and stop it in the end |
| 25 | + ctx = get_compute_context(log_create=True, log_reuse=True) |
| 26 | + try: |
| 27 | + common_execute_all("polars") |
| 28 | + finally: |
| 29 | + print(f"Stopping compute context: {ctx._compute_id}") |
| 30 | + stop_compute_context(ctx) |
| 31 | + |
| 32 | + |
18 | 33 | def _scan_ds(table_name: str) -> pl.LazyFrame: |
19 | 34 | path = get_table_path(table_name) |
| 35 | + # pathlib.Path normalizes consecutive slashes, unless Path.from_uri is used (Python >= 3.13) |
| 36 | + if isinstance(path, pathlib.Path) and str(path).startswith("s3:/") and not str(path).startswith("s3://"): |
| 37 | + path = f"s3://{str(path)[4:]}" |
20 | 38 |
|
21 | 39 | if settings.run.io_type == "skip": |
22 | 40 | return pl.read_parquet(path, rechunk=True).lazy() |
@@ -161,28 +179,12 @@ def run_query(query_number: int, lf: pl.LazyFrame) -> None: |
161 | 179 | if cloud: |
162 | 180 | import os |
163 | 181 |
|
164 | | - import polars_cloud as pc |
165 | | - |
166 | 182 | os.environ["POLARS_SKIP_CLIENT_CHECK"] = "1" |
167 | 183 |
|
168 | | - class PatchedComputeContext(pc.ComputeContext): |
169 | | - def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] |
170 | | - self._interactive = True |
171 | | - self._compute_address = "localhost:5051" |
172 | | - self._compute_public_key = b"" |
173 | | - self._compute_id = "1" # type: ignore[assignment] |
174 | | - |
175 | | - def get_status(self: pc.ComputeContext) -> pc.ComputeContextStatus: |
176 | | - """Get the status of the compute cluster.""" |
177 | | - return pc.ComputeContextStatus.RUNNING |
178 | | - |
179 | | - pc.ComputeContext.__init__ = PatchedComputeContext.__init__ # type: ignore[assignment] |
180 | | - pc.ComputeContext.get_status = PatchedComputeContext.get_status # type: ignore[method-assign] |
| 184 | + ctx = get_compute_context(create_if_no_reuse=False) |
181 | 185 |
|
182 | 186 | def query(): # type: ignore[no-untyped-def] |
183 | | - result = pc.spawn( |
184 | | - lf, dst="file:///tmp/dst/", distributed=True |
185 | | - ).await_result() |
| 187 | + result = lf.remote(context=ctx).distributed().collect() |
186 | 188 |
|
187 | 189 | if settings.run.show_results: |
188 | 190 | print(result.plan()) |
|
0 commit comments