diff --git a/queries/polars/__main__.py b/queries/polars/__main__.py
index 0fc97f5..13b7ae9 100644
--- a/queries/polars/__main__.py
+++ b/queries/polars/__main__.py
@@ -1,4 +1,4 @@
-from queries.common_utils import execute_all
+from queries.polars.utils import execute_all
 
 if __name__ == "__main__":
-    execute_all("polars")
+    execute_all()
diff --git a/queries/polars/cloud_utils.py b/queries/polars/cloud_utils.py
new file mode 100644
index 0000000..68e2207
--- /dev/null
+++ b/queries/polars/cloud_utils.py
@@ -0,0 +1,87 @@
+import base64
+import json
+import pathlib
+from uuid import UUID
+
+import polars_cloud as pc
+
+from settings import Settings
+
+settings = Settings()
+
+
+def reuse_compute_context(filename: str, log_reuse: bool) -> pc.ComputeContext | None:
+    with pathlib.Path(filename).open("r", encoding="utf8") as r:
+        context_args = json.load(r)
+
+    required_keys = ["workspace_id", "compute_id"]
+    for key in required_keys:
+        assert key in context_args, f"Key {key} not in (unknown)"
+    if log_reuse:
+        print(f"Reusing existing compute context: {context_args['compute_id']}")
+    context_args = {key: UUID(context_args.get(key)) for key in required_keys}
+    try:
+        ctx = pc.ComputeContext.connect(**context_args)
+        ctx.start(wait=True)
+        assert ctx.get_status() == pc.ComputeContextStatus.RUNNING
+    except RuntimeError as e:
+        print(f"Cannot reuse existing compute context: {e.args}")
+        return None
+    return ctx
+
+
+def get_compute_context_args() -> dict[str, str | int]:
+    return {
+        key: value
+        for key, value in {
+            "cpus": settings.run.polars_cloud_cpus,
+            "memory": settings.run.polars_cloud_memory,
+            "instance_type": settings.run.polars_cloud_instance_type,
+            "cluster_size": settings.run.polars_cloud_cluster_size,
+            "workspace": settings.run.polars_cloud_workspace,
+        }.items()
+        if value is not None
+    }
+
+
+def get_compute_context_filename(context_args: dict[str, str | int]) -> str:
+    hash = base64.b64encode(str(context_args).encode("utf-8")).decode()
+    return f".polars-cloud-compute-context-{hash}.json"
+
+
+def get_compute_context(
+    *,
+    create_if_no_reuse: bool = True,
+    log_create: bool = False,
+    log_reuse: bool = False,
+) -> pc.ComputeContext:
+    context_args = get_compute_context_args()
+    context_filename = get_compute_context_filename(context_args)
+    if pathlib.Path(context_filename).is_file():
+        ctx = reuse_compute_context(context_filename, log_reuse)
+        if ctx:
+            return ctx
+
+    # start new compute context
+    if not create_if_no_reuse:
+        msg = "Cannot reuse compute context"
+        raise RuntimeError(msg)
+    if log_create:
+        print(f"Starting new compute context: {context_args}")
+    ctx = pc.ComputeContext(**context_args)  # type: ignore[arg-type]
+    ctx.start(wait=True)
+    assert ctx.get_status() == pc.ComputeContextStatus.RUNNING
+    context_args = {
+        "workspace_id": str(ctx.workspace.id),
+        "compute_id": str(ctx._compute_id),
+    }
+    with pathlib.Path(context_filename).open("w", encoding="utf8") as w:
+        json.dump(context_args, w)
+    return ctx
+
+
+def stop_compute_context(ctx: pc.ComputeContext) -> None:
+    ctx.stop(wait=True)
+    context_args = get_compute_context_args()
+    context_filename = get_compute_context_filename(context_args)
+    pathlib.Path(context_filename).unlink(missing_ok=True)
diff --git a/queries/polars/utils.py b/queries/polars/utils.py
index 8bbc134..9e0d1ab 100644
--- a/queries/polars/utils.py
+++ b/queries/polars/utils.py
@@ -10,22 +10,45 @@
     get_table_path,
     run_query_generic,
 )
+from queries.common_utils import (
+    execute_all as common_execute_all,
+)
+from queries.polars.cloud_utils import get_compute_context, stop_compute_context
 from settings import Settings
 
 settings = Settings()
 
 
+def execute_all() -> None:
+    if not settings.run.polars_cloud:
+        return common_execute_all("polars")
+
+    # for polars cloud we have to create the compute context,
+    # reuse it across the queries, and stop it in the end
+    ctx = get_compute_context(log_create=True, log_reuse=True)
+    try:
+        common_execute_all("polars")
+    finally:
+        print(f"Stopping compute context: {ctx._compute_id}")
+        stop_compute_context(ctx)
+
+
 def _scan_ds(table_name: str) -> pl.LazyFrame:
     path = get_table_path(table_name)
+    # pathlib.Path normalizes consecutive slashes,
+    # unless Path.from_uri is used (Python >= 3.13)
+    path_str = str(path)
+    if path_str.startswith("s3:/") and not path_str.startswith("s3://"):
+        path_str = f"s3://{str(path)[4:]}"
 
     if settings.run.io_type == "skip":
-        return pl.read_parquet(path, rechunk=True).lazy()
+        return pl.read_parquet(path_str, rechunk=True).lazy()
     if settings.run.io_type == "parquet":
-        return pl.scan_parquet(path)
+        return pl.scan_parquet(path_str)
     elif settings.run.io_type == "feather":
-        return pl.scan_ipc(path)
+        return pl.scan_ipc(path_str)
     elif settings.run.io_type == "csv":
-        return pl.scan_csv(path, try_parse_dates=True)
+        return pl.scan_csv(path_str, try_parse_dates=True)
     else:
         msg = f"unsupported file type: {settings.run.io_type!r}"
         raise ValueError(msg)
@@ -161,32 +184,12 @@ def run_query(query_number: int, lf: pl.LazyFrame) -> None:
     if cloud:
         import os
 
-        import polars_cloud as pc
-
         os.environ["POLARS_SKIP_CLIENT_CHECK"] = "1"
 
-        class PatchedComputeContext(pc.ComputeContext):
-            def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
-                self._interactive = True
-                self._compute_address = "localhost:5051"
-                self._compute_public_key = b""
-                self._compute_id = "1"  # type: ignore[assignment]
-
-            def get_status(self: pc.ComputeContext) -> pc.ComputeContextStatus:
-                """Get the status of the compute cluster."""
-                return pc.ComputeContextStatus.RUNNING
-
-        pc.ComputeContext.__init__ = PatchedComputeContext.__init__  # type: ignore[assignment]
-        pc.ComputeContext.get_status = PatchedComputeContext.get_status  # type: ignore[method-assign]
+        ctx = get_compute_context(create_if_no_reuse=False)
 
         def query():  # type: ignore[no-untyped-def]
-            result = pc.spawn(
-                lf, dst="file:///tmp/dst/", distributed=True
-            ).await_result()
-
-            if settings.run.show_results:
-                print(result.plan())
-            return result.lazy().collect()
+            return lf.remote(context=ctx).distributed().collect()
     else:
         query = partial(
             lf.collect,
diff --git a/settings.py b/settings.py
index ac714d0..de2fe9d 100644
--- a/settings.py
+++ b/settings.py
@@ -49,6 +49,13 @@ class Run(BaseSettings):
         "cuda", "cuda-pool", "managed", "managed-pool", "cuda-async"
     ] = "cuda-async"
 
+    polars_cloud_cpus: int | None = 1  ## CPUs per node
+    polars_cloud_memory: int | None = 2  # GB per node
+    # use instance_type instead of cpus and memory, e.g. "t2.micro"
+    polars_cloud_instance_type: str | None = None
+    polars_cloud_cluster_size: int = 1  ## nodes in the cluster
+    polars_cloud_workspace: str | None = None
+
     modin_memory: int = 8_000_000_000  # Tune as needed for optimal performance
 
     spark_driver_memory: str = "2g"  # Tune as needed for optimal performance