Skip to content

Commit ba55e3c

Browse files
committed
Support polars cloud
1 parent 0e6eb96 commit ba55e3c

File tree

3 files changed

+29
-21
lines changed

3 files changed

+29
-21
lines changed

queries/polars/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from queries.common_utils import execute_all
1+
from queries.polars.utils import execute_all
22

33
if __name__ == "__main__":
4-
execute_all("polars")
4+
execute_all()

queries/polars/utils.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,37 @@
44
from typing import Literal
55

66
import polars as pl
7-
87
from queries.common_utils import (
98
check_query_result_pl,
9+
execute_all as common_execute_all,
1010
get_table_path,
1111
run_query_generic,
1212
)
13+
from queries.polars.cloud_utils import get_compute_context, stop_compute_context
1314
from settings import Settings
1415

1516
settings = Settings()
1617

1718

19+
def execute_all() -> None:
20+
if not settings.run.polars_cloud:
21+
return execute_all("polars")
22+
23+
# for polars cloud we have to create the compute context,
24+
# reuse it across the queries, and stop it in the end
25+
ctx = get_compute_context(log_create=True, log_reuse=True)
26+
try:
27+
common_execute_all("polars")
28+
finally:
29+
print(f"Stopping compute context: {ctx._compute_id}")
30+
stop_compute_context(ctx)
31+
32+
1833
def _scan_ds(table_name: str) -> pl.LazyFrame:
1934
path = get_table_path(table_name)
35+
# pathlib.Path normalizes consecutive slashes, unless Path.from_uri is used (Python >= 3.13)
36+
if isinstance(path, pathlib.Path) and str(path).startswith("s3:/") and not str(path).startswith("s3://"):
37+
path = f"s3://{str(path)[4:]}"
2038

2139
if settings.run.io_type == "skip":
2240
return pl.read_parquet(path, rechunk=True).lazy()
@@ -161,28 +179,12 @@ def run_query(query_number: int, lf: pl.LazyFrame) -> None:
161179
if cloud:
162180
import os
163181

164-
import polars_cloud as pc
165-
166182
os.environ["POLARS_SKIP_CLIENT_CHECK"] = "1"
167183

168-
class PatchedComputeContext(pc.ComputeContext):
169-
def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
170-
self._interactive = True
171-
self._compute_address = "localhost:5051"
172-
self._compute_public_key = b""
173-
self._compute_id = "1" # type: ignore[assignment]
174-
175-
def get_status(self: pc.ComputeContext) -> pc.ComputeContextStatus:
176-
"""Get the status of the compute cluster."""
177-
return pc.ComputeContextStatus.RUNNING
178-
179-
pc.ComputeContext.__init__ = PatchedComputeContext.__init__ # type: ignore[assignment]
180-
pc.ComputeContext.get_status = PatchedComputeContext.get_status # type: ignore[method-assign]
184+
ctx = get_compute_context(create_if_no_reuse=False)
181185

182186
def query(): # type: ignore[no-untyped-def]
183-
result = pc.spawn(
184-
lf, dst="file:///tmp/dst/", distributed=True
185-
).await_result()
187+
result = lf.remote(context=ctx).distributed().collect()
186188

187189
if settings.run.show_results:
188190
print(result.plan())

settings.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ class Run(BaseSettings):
4949
"cuda", "cuda-pool", "managed", "managed-pool", "cuda-async"
5050
] = "cuda-async"
5151

52+
polars_cloud_cpus: int | None = 1 ## CPUs per node
53+
polars_cloud_memory: int | None = 2 # GB per node
54+
polars_cloud_instance_type: str | None = None # use instance_type instead of cpus and memory, e.g. "t2.micro"
55+
polars_cloud_cluster_size: int = 1 ## nodes in the cluster
56+
polars_cloud_workspace: str | None = None
57+
5258
modin_memory: int = 8_000_000_000 # Tune as needed for optimal performance
5359

5460
spark_driver_memory: str = "2g" # Tune as needed for optimal performance

0 commit comments

Comments
 (0)