Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions queries/polars/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Entry point: `python -m queries.polars` runs the full polars query suite.
# The polars-specific `execute_all` dispatches to local or polars-cloud
# execution based on settings.
from queries.polars.utils import execute_all

if __name__ == "__main__":
    execute_all()
87 changes: 87 additions & 0 deletions queries/polars/cloud_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import base64
import json
import pathlib
from uuid import UUID

import polars_cloud as pc

from settings import Settings

settings = Settings()


def reuse_compute_context(filename: str, log_reuse: bool) -> pc.ComputeContext | None:
    """Try to reconnect to a previously persisted compute context.

    Reads the JSON cache file written by `get_compute_context` and reconnects
    to the compute context it describes.

    Returns the running context, or None when it cannot be reused (e.g. the
    cluster was stopped externally or never reaches the RUNNING state), so the
    caller can fall back to starting a fresh one.

    Raises KeyError when the cache file is missing a required key.
    """
    with pathlib.Path(filename).open("r", encoding="utf8") as r:
        context_args = json.load(r)

    required_keys = ["workspace_id", "compute_id"]
    for key in required_keys:
        # Raise instead of assert: asserts are stripped under `python -O`.
        if key not in context_args:
            msg = f"Key {key} not in {filename}"
            raise KeyError(msg)
    if log_reuse:
        print(f"Reusing existing compute context: {context_args['compute_id']}")
    context_args = {key: UUID(context_args[key]) for key in required_keys}
    try:
        ctx = pc.ComputeContext.connect(**context_args)
        ctx.start(wait=True)
        if ctx.get_status() != pc.ComputeContextStatus.RUNNING:
            # Treated as non-reusable (caught below) rather than fatal.
            msg = f"compute context is not running: {ctx.get_status()}"
            raise RuntimeError(msg)
    except RuntimeError as e:
        print(f"Cannot reuse existing compute context: {e.args}")
        return None
    return ctx


def get_compute_context_args() -> dict[str, str | int]:
    """Collect the configured polars-cloud cluster options.

    Only options that are actually set (non-None) are returned, so the
    library defaults apply for the rest.
    """
    run_settings = settings.run
    candidates: dict[str, str | int | None] = {
        "cpus": run_settings.polars_cloud_cpus,
        "memory": run_settings.polars_cloud_memory,
        "instance_type": run_settings.polars_cloud_instance_type,
        "cluster_size": run_settings.polars_cloud_cluster_size,
        "workspace": run_settings.polars_cloud_workspace,
    }
    args: dict[str, str | int] = {}
    for name, value in candidates.items():
        if value is not None:
            args[name] = value
    return args


def get_compute_context_filename(context_args: dict[str, str | int]) -> str:
hash = base64.b64encode(str(context_args).encode("utf-8")).decode()
return f".polars-cloud-compute-context-{hash}.json"


def get_compute_context(
    *,
    create_if_no_reuse: bool = True,
    log_create: bool = False,
    log_reuse: bool = False,
) -> pc.ComputeContext:
    """Return a running compute context, reusing a persisted one if possible.

    The context ids are cached in a JSON file keyed on the cluster
    configuration; later calls (also from other processes) reconnect to the
    same cluster instead of starting a new one.

    Raises RuntimeError when nothing can be reused and `create_if_no_reuse`
    is False, or when a newly started context does not reach RUNNING.
    """
    context_args = get_compute_context_args()
    context_filename = get_compute_context_filename(context_args)
    if pathlib.Path(context_filename).is_file():
        ctx = reuse_compute_context(context_filename, log_reuse)
        if ctx:
            return ctx

    # start new compute context
    if not create_if_no_reuse:
        msg = "Cannot reuse compute context"
        raise RuntimeError(msg)
    if log_create:
        print(f"Starting new compute context: {context_args}")
    ctx = pc.ComputeContext(**context_args)  # type: ignore[arg-type]
    ctx.start(wait=True)
    # Raise instead of assert: asserts are stripped under `python -O`.
    if ctx.get_status() != pc.ComputeContextStatus.RUNNING:
        msg = f"Compute context failed to start: {ctx.get_status()}"
        raise RuntimeError(msg)
    # Persist the ids so this context can be reconnected to later.
    # NOTE(review): `_compute_id` is a private polars-cloud attribute —
    # confirm there is no public accessor.
    context_args = {
        "workspace_id": str(ctx.workspace.id),
        "compute_id": str(ctx._compute_id),
    }
    with pathlib.Path(context_filename).open("w", encoding="utf8") as w:
        json.dump(context_args, w)
    return ctx


def stop_compute_context(ctx: pc.ComputeContext) -> None:
    """Stop the given compute context and drop its on-disk cache file."""
    ctx.stop(wait=True)
    cache_file = get_compute_context_filename(get_compute_context_args())
    pathlib.Path(cache_file).unlink(missing_ok=True)
55 changes: 29 additions & 26 deletions queries/polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,45 @@
get_table_path,
run_query_generic,
)
from queries.common_utils import (
execute_all as common_execute_all,
)
from queries.polars.cloud_utils import get_compute_context, stop_compute_context
from settings import Settings

settings = Settings()


def execute_all() -> None:
    """Run all polars queries, either locally or on polars cloud."""
    if not settings.run.polars_cloud:
        common_execute_all("polars")
        return

    # Polars cloud: create one compute context up front, share it across all
    # queries, and always stop it at the end.
    ctx = get_compute_context(log_create=True, log_reuse=True)
    try:
        common_execute_all("polars")
    finally:
        print(f"Stopping compute context: {ctx._compute_id}")
        stop_compute_context(ctx)


def _scan_ds(table_name: str) -> pl.LazyFrame:
    """Return a LazyFrame for the given table, honoring the configured IO type.

    Raises ValueError for an unsupported `settings.run.io_type`.

    NOTE(review): this span was diff-garbled (old/new lines interleaved);
    reconstructed to the post-change version that scans `path_str`.
    """
    path = get_table_path(table_name)
    # pathlib.Path normalizes consecutive slashes,
    # unless Path.from_uri is used (Python >= 3.13),
    # so restore the "s3://" scheme that Path collapsed to "s3:/".
    path_str = str(path)
    if path_str.startswith("s3:/") and not path_str.startswith("s3://"):
        path_str = f"s3://{str(path)[4:]}"

    if settings.run.io_type == "skip":
        # "skip" excludes IO from the benchmark: eagerly read, then wrap lazily.
        return pl.read_parquet(path_str, rechunk=True).lazy()
    if settings.run.io_type == "parquet":
        return pl.scan_parquet(path_str)
    elif settings.run.io_type == "feather":
        return pl.scan_ipc(path_str)
    elif settings.run.io_type == "csv":
        return pl.scan_csv(path_str, try_parse_dates=True)
    else:
        msg = f"unsupported file type: {settings.run.io_type!r}"
        raise ValueError(msg)
Expand Down Expand Up @@ -161,32 +184,12 @@ def run_query(query_number: int, lf: pl.LazyFrame) -> None:
if cloud:
import os

import polars_cloud as pc

os.environ["POLARS_SKIP_CLIENT_CHECK"] = "1"

class PatchedComputeContext(pc.ComputeContext):
def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
self._interactive = True
self._compute_address = "localhost:5051"
self._compute_public_key = b""
self._compute_id = "1" # type: ignore[assignment]

def get_status(self: pc.ComputeContext) -> pc.ComputeContextStatus:
"""Get the status of the compute cluster."""
return pc.ComputeContextStatus.RUNNING

pc.ComputeContext.__init__ = PatchedComputeContext.__init__ # type: ignore[assignment]
pc.ComputeContext.get_status = PatchedComputeContext.get_status # type: ignore[method-assign]
ctx = get_compute_context(create_if_no_reuse=False)

def query(): # type: ignore[no-untyped-def]
result = pc.spawn(
lf, dst="file:///tmp/dst/", distributed=True
).await_result()

if settings.run.show_results:
print(result.plan())
return result.lazy().collect()
return lf.remote(context=ctx).distributed().collect()
else:
query = partial(
lf.collect,
Expand Down
7 changes: 7 additions & 0 deletions settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ class Run(BaseSettings):
"cuda", "cuda-pool", "managed", "managed-pool", "cuda-async"
] = "cuda-async"

polars_cloud_cpus: int | None = 1 ## CPUs per node
polars_cloud_memory: int | None = 2 # GB per node
# use instance_type instead of cpus and memory, e.g. "t2.micro"
polars_cloud_instance_type: str | None = None
polars_cloud_cluster_size: int = 1 ## nodes in the cluster
polars_cloud_workspace: str | None = None

modin_memory: int = 8_000_000_000 # Tune as needed for optimal performance

spark_driver_memory: str = "2g" # Tune as needed for optimal performance
Expand Down