Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions tpch/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@

import narwhals as nw
from narwhals.exceptions import NarwhalsError
from tpch import constants
from tpch.constants import (
DATABASE_TABLE_NAMES,
LOGGER_NAME,
QUERIES_PACKAGE,
QUERY_IDS,
_scale_factor_dir,
)
Comment on lines +9 to +16
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this an intentional change, or lost in the conflicts?

In (134875f) I switched to constants so that the imports didn't take up so many lines 😢

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see! Sorry feel free to revert it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll keep it for now, this next branch has waaaay more of a negative diff 😄

tpch/refactor-cli-sf-folders...tpch/refactor-cli-choices


if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -51,32 +57,48 @@ def scan(self, source: FileSource) -> nw.LazyFrame[Any]:

class Query:
id: QueryID
paths: tuple[Path, ...]
table_names: tuple[str, ...]
_into_xfails: tuple[tuple[Predicate, str, XFailRaises], ...]
_into_skips: tuple[tuple[Predicate, str], ...]

def __init__(self, query_id: QueryID, paths: tuple[Path, ...]) -> None:
def __init__(self, query_id: QueryID, table_names: tuple[str, ...]) -> None:
self.id = query_id
self.paths = paths
self.table_names = table_names
self._into_xfails = ()
self._into_skips = ()

def __repr__(self) -> str:
    # Use the query id itself (a string QueryID) as the repr, so the
    # object interpolates directly into file names like f"result_{self}".
    return self.id

def inputs(
    self, backend: Backend, scale_factor: float
) -> tuple[nw.LazyFrame[Any], ...]:
    """Scan this query's input tables from the scale-factor directory.

    One lazy frame is produced per entry in ``table_names``, in order.
    """
    directory = _scale_factor_dir(scale_factor)
    frames = []
    for table in self.table_names:
        source = (directory / f"{table}.parquet").as_posix()
        frames.append(backend.scan(source))
    return tuple(frames)

def expected(self, scale_factor: float) -> pl.DataFrame:
    """Read the reference answer for this query at the given scale factor."""
    answer_path = _scale_factor_dir(scale_factor) / f"result_{self}.parquet"
    return pl.read_parquet(answer_path)

def execute(
self, backend: Backend, scale_factor: float, request: pytest.FixtureRequest
) -> None:
self._apply_skips(backend, scale_factor)
data = tuple(backend.scan(fp.as_posix()) for fp in self.paths)
data = self.inputs(backend=backend, scale_factor=scale_factor)
query = self._import_module().query

try:
result = query(*data).lazy().collect("polars").to_polars()
except NarwhalsError as exc:
msg = f"Query [{self}-{backend}] ({scale_factor=}) failed with the following error in Narwhals:\n{exc}"
raise RuntimeError(msg) from exc

self._apply_xfails(backend, scale_factor, request)
expected = pl.read_parquet(constants.DATA_DIR / f"result_{self}.parquet")
expected = self.expected(scale_factor=scale_factor)
try:
pl_assert_frame_equal(expected, result, check_dtypes=False)
except AssertionError as exc:
Expand Down Expand Up @@ -107,11 +129,11 @@ def _apply_xfails(
request.applymarker(mark)

def _import_module(self) -> QueryModule:
result: Any = import_module(f"{constants.QUERIES_PACKAGE}.{self}")
result: Any = import_module(f"{QUERIES_PACKAGE}.{self}")
return result


logger = logging.getLogger(constants.LOGGER_NAME)
logger = logging.getLogger(LOGGER_NAME)


class TableLogger:
Expand All @@ -125,11 +147,11 @@ def __init__(self, file_names: Iterable[str]) -> None:

@staticmethod
def answers() -> TableLogger:
return TableLogger(f"result_{qid}.parquet" for qid in constants.QUERY_IDS)
return TableLogger(f"result_{qid}.parquet" for qid in QUERY_IDS)

@staticmethod
def database() -> TableLogger:
return TableLogger(f"{t}.parquet" for t in constants.DATABASE_TABLE_NAMES)
return TableLogger(f"{t}.parquet" for t in DATABASE_TABLE_NAMES)

def __enter__(self) -> Self:
# header
Expand Down
13 changes: 8 additions & 5 deletions tpch/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import cache
from pathlib import Path
from typing import TYPE_CHECKING, get_args

Expand All @@ -11,12 +12,14 @@
REPO_ROOT = Path(__file__).parent.parent
TPCH_DIR = REPO_ROOT / "tpch"
DATA_DIR = TPCH_DIR / "data"
METADATA_PATH = DATA_DIR / "metadata.csv"
"""For reflection in tests.

E.g. if we *know* the query is not valid for a given `scale_factor`,
then we can determine if a failure is expected.
"""

@cache
def _scale_factor_dir(scale_factor: float) -> Path:
    """Return the data sub-directory for one scale factor (e.g. ``data/sf0.1``)."""
    # Cached: the same Path object is returned for repeated lookups of a factor.
    return DATA_DIR.joinpath(f"sf{scale_factor}")


DATABASE_TABLE_NAMES = (
"lineitem",
"customer",
Expand Down
98 changes: 30 additions & 68 deletions tpch/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from __future__ import annotations

import argparse
import datetime as dt
import logging
import os
from functools import cache
Expand All @@ -21,8 +20,8 @@
DATABASE_TABLE_NAMES,
GLOBS,
LOGGER_NAME,
METADATA_PATH,
QUERY_IDS,
_scale_factor_dir,
)

if TYPE_CHECKING:
Expand All @@ -42,13 +41,14 @@ def read_fmt_schema(fp: Path) -> str:
return f"- {fp.name}\n" + "\n".join(f" - {k:<20}: {v}" for k, v in schema)


def show_schemas(artifact: Artifact, /) -> None:
def show_schemas(artifact: Artifact, scale_factor: float, /) -> None:
if not logger.isEnabledFor(logging.DEBUG):
return
pattern = GLOBS[artifact]
paths = sorted(DATA_DIR.glob(pattern))
sf_dir = _scale_factor_dir(scale_factor)
paths = sorted(sf_dir.glob(pattern))
if not paths:
msg = f"Found no matching paths for {pattern!r} in {DATA_DIR.as_posix()}"
msg = f"Found no matching paths for {pattern!r} in {sf_dir.as_posix()}"
raise NotImplementedError(msg)
msg = "\n".join(read_fmt_schema(fp) for fp in paths)
logger.debug("Schemas (%s):\n%s", artifact, msg)
Expand Down Expand Up @@ -103,24 +103,27 @@ def load_tpch_extension(con: Con) -> Con:


def generate_tpch_database(con: Con, scale_factor: float) -> Con:
logger.info("Generating data with scale_factor=%s", scale_factor)
logger.info("Generating data for scale_factor=%s", scale_factor)
con.sql(SQL_DBGEN.format(scale_factor))
logger.info("Finished generating data.")
return con


def write_tpch_database(con: Con) -> Con:
logger.info("Writing data to: %s", DATA_DIR.as_posix())
def write_tpch_database(con: Con, scale_factor: float) -> Con:
sf_dir = _scale_factor_dir(scale_factor)
sf_dir.mkdir(parents=True, exist_ok=True)
logger.info("Writing data to: %s", sf_dir.as_posix())
with TableLogger.database() as tbl_logger:
for t in DATABASE_TABLE_NAMES:
path = DATA_DIR / f"{t}.parquet"
path = sf_dir / f"{t}.parquet"
con.sql(SQL_FROM.format(t)).pl().cast(cast_map()).write_parquet(path)
tbl_logger.log_row(path)
show_schemas("database")
show_schemas("database", scale_factor)
return con


def write_tpch_answers(con: Con) -> Con:
def write_tpch_answers(con: Con, scale_factor: float) -> Con:
sf_dir = _scale_factor_dir(scale_factor)
logger.info("Executing tpch queries for answers")
with TableLogger.answers() as tbl_logger:
for query_id in QUERY_IDS:
Expand All @@ -129,76 +132,35 @@ def write_tpch_answers(con: Con) -> Con:
df = con.sql(query).pl().cast(cast_map())
if fix := FIX_ANSWERS.get(query_id):
df = fix(df)
path = sf_dir / f"result_{query_id}.parquet"
df.write_parquet(path)
tbl_logger.log_row(path)
show_schemas("answers")
show_schemas("answers", scale_factor)
return con


def write_metadata(scale_factor: float) -> None:
    """Persist the scale factor and a UTC timestamp to the metadata CSV."""
    METADATA_PATH.touch()
    logger.info("Writing metadata to: %s", METADATA_PATH.name)
    now = dt.datetime.now(dt.timezone.utc)
    frame = pl.DataFrame({"scale_factor": [scale_factor], "modified_time": [now]})
    frame.write_csv(METADATA_PATH)


def _validate_metadata(metadata: pl.DataFrame) -> tuple[float, dt.datetime]:
    """Validate the one-row metadata frame and return its contents.

    Returns:
        ``(scale_factor, modified_time)`` from the first row.

    Raises:
        ValueError: if the columns are not exactly the expected ones.
        TypeError: if the values do not have the expected types.
    """
    meta = metadata.row(0, named=True)
    expected_columns = "scale_factor", "modified_time"
    if meta.keys() != set(expected_columns):
        # BUG FIX: the second f-string was a separate, discarded expression
        # statement, so the raised message was silently truncated after the
        # first line. Parenthesize so the two pieces concatenate.
        msg = (
            f"Found unexpected columns in {METADATA_PATH.name!r}.\n"
            f"Expected: {expected_columns!r}\nGot: {tuple(meta)!r}"
        )
        raise ValueError(msg)
    scale_factor = meta["scale_factor"]
    modified_time = meta["modified_time"]
    if isinstance(scale_factor, float) and isinstance(modified_time, dt.datetime):
        logger.info(
            "Found existing metadata: scale_factor=%s, modified_time=%s",
            scale_factor,
            modified_time,
        )
        return (scale_factor, modified_time)
    msg = (
        f"Found unexpected data in {METADATA_PATH.name!r}.\n"
        f"Expected: ({float.__name__!r}, {dt.datetime.__name__!r})\n"
        f"Got: {(type(scale_factor).__name__, type(modified_time).__name__)!r}"
    )
    raise TypeError(msg)


def try_read_metadata() -> tuple[float, dt.datetime] | None:
    """Load and validate existing metadata, or return ``None`` when absent."""
    logger.info("Trying to read metadata from: %s", METADATA_PATH.name)
    if METADATA_PATH.exists():
        frame = pl.read_csv(METADATA_PATH, try_parse_dates=True)
        return _validate_metadata(frame)
    logger.info("Did not find existing metadata")
    return None
def scale_factor_exists(scale_factor: float) -> bool:
    """Report whether a data directory was already generated for this scale factor."""
    return _scale_factor_dir(scale_factor).exists()


def main(*, scale_factor: float = 0.1, refresh: bool = False) -> None:
DATA_DIR.mkdir(exist_ok=True)
if refresh:
logger.info("Refreshing data")
elif meta := try_read_metadata():
if meta[0] == scale_factor:
logger.info(
"Existing metadata matches requested scale_factor=%s", scale_factor
)
show_schemas("database")
show_schemas("answers")
logger.info("To regenerate this scale_factor, use `--refresh`")
return
logger.info(
"Existing metadata does not match requested scale_factor=%s", scale_factor
)
logger.info("Refreshing data for scale_factor=%s", scale_factor)
elif scale_factor_exists(scale_factor):
logger.info("Data already exists for scale_factor=%s", scale_factor)
show_schemas("database", scale_factor)
show_schemas("answers", scale_factor)
logger.info("To regenerate this scale_factor, use `--refresh`")
return

con = connect()
load_tpch_extension(con)
generate_tpch_database(con, scale_factor)
write_tpch_database(con)
write_tpch_answers(con)
write_metadata(scale_factor)
write_tpch_database(con, scale_factor)
write_tpch_answers(con, scale_factor)
total = TableLogger.format_size(sum(e.stat().st_size for e in os.scandir(DATA_DIR)))
logger.info("Finished with total file size: %s", total.strip())

Expand Down
Loading
Loading