diff --git a/tpch/classes.py b/tpch/classes.py index 927e816aa7..d7939d8e2c 100644 --- a/tpch/classes.py +++ b/tpch/classes.py @@ -2,7 +2,7 @@ import logging from importlib import import_module -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import narwhals as nw from narwhals.exceptions import NarwhalsError @@ -20,23 +20,16 @@ from pathlib import Path import polars as pl - import pytest from typing_extensions import Self + from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl from narwhals.typing import FileSource - from tpch.typing_ import ( - KnownImpl, - Predicate, - QueryID, - QueryModule, - TPCHBackend, - XFailRaises, - ) + from tpch.typing_ import QueryID, QueryModule, ScaleFactor, TPCHBackend class Backend: name: TPCHBackend - implementation: KnownImpl + implementation: Literal[_EagerAllowedImpl, _LazyAllowedImpl] kwds: dict[str, Any] def __init__(self, name: TPCHBackend, /, **kwds: Any) -> None: @@ -57,15 +50,11 @@ def scan(self, source: FileSource) -> nw.LazyFrame[Any]: class Query: id: QueryID table_names: tuple[str, ...] - scale_factor: float - _into_xfails: tuple[tuple[Predicate, str, XFailRaises], ...] - _into_skips: tuple[tuple[Predicate, str], ...] + scale_factor: ScaleFactor def __init__(self, query_id: QueryID, table_names: tuple[str, ...]) -> None: self.id = query_id self.table_names = table_names - self._into_xfails = () - self._into_skips = () self.scale_factor = SCALE_FACTOR_DEFAULT def __repr__(self) -> str: @@ -85,10 +74,9 @@ def expected(self) -> pl.DataFrame: sf_dir = _scale_factor_dir(self.scale_factor) return pl.read_parquet(sf_dir / f"result_{self}.parquet") - def execute(self, backend: Backend, request: pytest.FixtureRequest) -> None: + def execute(self, backend: Backend) -> None: from polars.testing import assert_frame_equal - self._apply_skips(backend) data = self.inputs(backend) query = self._import_module().query @@ -97,8 +85,6 @@ def execute(self, backend: Backend, request: pytest.FixtureRequest) -> None: except NarwhalsError as exc: msg = f"Query [{self}-{backend}] ({self.scale_factor=}) failed with the following error in Narwhals:\n{exc}" raise RuntimeError(msg) from exc - - self._apply_xfails(backend, request) expected = self.expected() try: assert_frame_equal(expected, result, check_dtypes=False) @@ -106,35 +92,10 @@ def execute(self, backend: Backend, request: pytest.FixtureRequest) -> None: msg = f"Query [{self}-{backend}] ({self.scale_factor=}) resulted in wrong answer:\n{exc}" raise AssertionError(msg) from exc - def with_scale_factor(self, scale_factor: float, /) -> Query: + def with_scale_factor(self, scale_factor: ScaleFactor, /) -> Query: self.scale_factor = scale_factor return self - def with_skip(self, predicate: Predicate, reason: str) -> Query: - self._into_skips = (*self._into_skips, (predicate, reason)) - return self - - def with_xfail( - self, predicate: Predicate, reason: str, *, raises: XFailRaises = AssertionError - ) -> Query: - self._into_xfails = (*self._into_xfails, (predicate, reason, raises)) - return self - - def _apply_skips(self, backend: Backend) -> None: - import pytest - - for predicate, reason in self._into_skips: - if predicate(backend, self.scale_factor): - pytest.skip(reason) - - def _apply_xfails(self, backend: Backend, request: pytest.FixtureRequest) -> None: - import pytest - - for predicate, reason, raises in self._into_xfails: - condition = predicate(backend, self.scale_factor) - mark = pytest.mark.xfail(condition, reason=reason, raises=raises) - request.applymarker(mark) - def _import_module(self) -> QueryModule: result: Any = import_module(f"{QUERIES_PACKAGE}.{self}") return result diff --git a/tpch/constants.py b/tpch/constants.py index 4e5580291e..644b3468b7 100644 --- a/tpch/constants.py +++ b/tpch/constants.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import TYPE_CHECKING, get_args -from tpch.typing_ import Artifact, QueryID +from tpch.typing_ import Artifact, QueryID, ScaleFactor if TYPE_CHECKING: from collections.abc import Mapping @@ -12,17 +12,19 @@ REPO_ROOT = Path(__file__).parent.parent TPCH_DIR = REPO_ROOT / "tpch" DATA_DIR = TPCH_DIR / "data" +DB_PATH = DATA_DIR / "narwhals.duckdb" @cache -def _scale_factor_dir(scale_factor: float) -> Path: +def _scale_factor_dir(scale_factor: ScaleFactor) -> Path: """Get the data directory for a specific scale factor.""" sf_dir = DATA_DIR / f"sf{scale_factor}" sf_dir.mkdir(parents=True, exist_ok=True) return sf_dir -SCALE_FACTOR_DEFAULT = 0.1 +SCALE_FACTORS: tuple[ScaleFactor, ...] = get_args(ScaleFactor) +SCALE_FACTOR_DEFAULT: ScaleFactor = "0.1" DATABASE_TABLE_NAMES = ( "lineitem", "customer", diff --git a/tpch/generate_data.py b/tpch/generate_data.py index 197891e429..dbae28205e 100644 --- a/tpch/generate_data.py +++ b/tpch/generate_data.py @@ -11,49 +11,64 @@ import os import sys from functools import cache -from typing import TYPE_CHECKING, Any +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal from tpch.classes import TableLogger from tpch.constants import ( DATABASE_TABLE_NAMES, + DB_PATH, GLOBS, LOGGER_NAME, QUERY_IDS, SCALE_FACTOR_DEFAULT, + SCALE_FACTORS, _scale_factor_dir, ) if TYPE_CHECKING: from collections.abc import Callable, Iterator, Mapping - from pathlib import Path import polars as pl import pytest from duckdb import DuckDBPyConnection as Con, DuckDBPyRelation as Rel + from typing_extensions import LiteralString - from tpch.typing_ import Artifact, QueryID + from tpch.typing_ import Artifact, QueryID, ScaleFactor logger = logging.getLogger(LOGGER_NAME) TABLE_SCALE_FACTOR = """ -┌──────────────┬───────────────┐ -│ Scale factor ┆ Database (MB) │ -╞══════════════╪═══════════════╡ -│ 0.1 ┆ 25 │ -│ 1.0 ┆ 250 │ -│ 3.0 ┆ 754 │ -│ 100.0 ┆ 26624 │ -└──────────────┴───────────────┘ +┌───────┬────────────┬─────────────┐ +│ sf ┆ Disk ┆ Memory (db) │ +╞═══════╪════════════╪═════════════╡ +│ 0.014 ┆ 3.25 mb ┆ 8.79 mb │ +│ 0.052 ┆ 12.01 mb ┆ 32.49 mb │ +│ 0.1 ┆ 23.15 mb ┆ 62.62 mb │ +│ 0.25 ┆ 58.90 mb ┆ 159.32 mb │ +│ 0.51 ┆ 124.40 mb ┆ 336.50 mb │ +│ 1.0 ┆ 247.66 mb ┆ 669.92 mb │ +│ 10.0 ┆ 2.59 gb ┆ 7.00 gb │ +│ 30.0 ┆ 7.76 gb ┆ 21.00 gb │ +└───────┴────────────┴─────────────┘ """ # NOTE: Store queries here, add parameter names if needed -SQL_DBGEN = "CALL dbgen(sf={0})" +SQL_DBGEN = "CALL dbgen(sf=$sf)" +SQL_DBGEN_BATCHED = "CALL dbgen(sf=$sf, children=$children, step=$step)" SQL_TPCH_ANSWER = "PRAGMA tpch({0})" SQL_FROM = "FROM {0}" +SQL_SHOW_DB = """ +SELECT + "table": name, + "schema": MAP(column_names, column_types) +FROM + (SHOW ALL TABLES) +""" -FIX_ANSWERS: Mapping[QueryID, Callable[[pl.DataFrame], pl.DataFrame]] = { +FIX_ANSWERS: Mapping[QueryID, Callable[[pl.LazyFrame], pl.LazyFrame]] = { "q18": lambda df: df.rename({"sum(l_quantity)": "sum"}).cast({"sum": int}), "q22": lambda df: df.cast({"cntrycode": int}), } @@ -82,10 +97,10 @@ def cast_map() -> dict[Any, Any]: @dataclasses.dataclass(**({"kw_only": True} if sys.version_info >= (3, 10) else {})) class TPCHGen: - scale_factor: float + scale_factor: ScaleFactor refresh: bool = False debug: bool = False - _con: Con = dataclasses.field(init=False) + _con: Con = dataclasses.field(init=False, repr=False) @staticmethod def from_pytest(config: pytest.Config, /) -> TPCHGen: @@ -99,6 +114,10 @@ def from_argparse(parser: argparse.ArgumentParser, /) -> TPCHGen: def scale_factor_dir(self) -> Path: return _scale_factor_dir(self.scale_factor) + @property + def _database(self) -> Path | Literal[":memory:"]: + return DB_PATH if self.scale_factor == "30.0" else ":memory:" + def glob(self, artifact: Artifact, /) -> Iterator[Path]: return self.scale_factor_dir.glob(GLOBS[artifact]) @@ -115,7 +134,7 @@ def run(self) -> None: self.show_schemas("database").show_schemas("answers") logger.info("To regenerate this scale_factor, use `--refresh`") return - self.connect().load_extension().generate_database().write_database().write_answers() + self.connect().load_extension().generate_database().write_database().write_answers().disconnect() n_bytes = sum(e.stat().st_size for e in os.scandir(self.scale_factor_dir)) total = TableLogger.format_size(n_bytes) logger.info("Finished with total file size: %s", total.strip()) @@ -123,13 +142,21 @@ def run(self) -> None: def connect(self) -> TPCHGen: import duckdb - logger.info("Connecting to in-memory DuckDB database") - self._con = duckdb.connect(database=":memory:") + database = self._database + name = database.as_posix() if isinstance(database, Path) else database + logger.info("Connecting to DuckDB database: %s", name) + self._con = duckdb.connect(database) + return self + + def disconnect(self) -> TPCHGen: + self._con.close() + if isinstance(self._database, Path): + logger.info("Dropping DuckDB database: %s", self._database.as_posix()) + self._database.unlink(missing_ok=True) return self - # TODO @dangotbanned: Change to `LiteralString` after restricting `--scale-factor` - def sql(self, query: str) -> Rel: - return self._con.sql(query) + def sql(self, query: LiteralString, **params: LiteralString | int) -> Rel: + return self._con.sql(query, params=params or None) def load_extension(self) -> TPCHGen: logger.info("Installing DuckDB TPC-H Extension") @@ -137,10 +164,27 @@ def load_extension(self) -> TPCHGen: self._con.load_extension("tpch") return self + def _generate_database_batched(self, batches: int) -> TPCHGen: + logger.info("Whelp, this may take a while...") + logger.info("Generating in %s batches", batches) + for batch in range(batches): + self.sql( + SQL_DBGEN_BATCHED, sf=self.scale_factor, children=batches, step=batch + ) + logger.info("Generated (%s/%s)", batch + 1, batches) + return self + def generate_database(self) -> TPCHGen: - logger.info("Generating data for scale_factor=%s", self.scale_factor) - self.sql(SQL_DBGEN.format(self.scale_factor)) + sf = self.scale_factor + logger.info("Generating data for scale_factor=%s", sf) + if sf in {"10.0", "30.0"}: + self._generate_database_batched(12 if sf == "10.0" else 4) + else: + self.sql(SQL_DBGEN, sf=sf) logger.info("Finished generating data.") + if logger.isEnabledFor(logging.DEBUG): + msg = str(self.sql(SQL_SHOW_DB))[:-1] + logger.debug("DuckDB schemas (database):\n%s", msg) return self def write_database(self) -> TPCHGen: @@ -148,7 +192,7 @@ def write_database(self) -> TPCHGen: with TableLogger.database() as tbl_logger: for t in DATABASE_TABLE_NAMES: path = self.scale_factor_dir / f"{t}.parquet" - self.sql(SQL_FROM.format(t)).pl().cast(cast_map()).write_parquet(path) + to_polars(self.sql(SQL_FROM.format(t))).sink_parquet(path) tbl_logger.log_row(path) return self.show_schemas("database") @@ -157,11 +201,11 @@ def write_answers(self) -> TPCHGen: with TableLogger.answers() as tbl_logger: for query_id in QUERY_IDS: query = SQL_TPCH_ANSWER.format(query_id.removeprefix("q")) - df = self.sql(query).pl().cast(cast_map()) + lf = to_polars(self.sql(query)) if fix := FIX_ANSWERS.get(query_id): - df = fix(df) + lf = fix(lf) path = self.scale_factor_dir / f"result_{query_id}.parquet" - df.write_parquet(path) + lf.sink_parquet(path) tbl_logger.log_row(path) return self.show_schemas("answers") @@ -169,18 +213,22 @@ def show_schemas(self, artifact: Artifact, /) -> TPCHGen: if logger.isEnabledFor(logging.DEBUG): if paths := sorted(self.glob(artifact)): msg = "\n".join(read_fmt_schema(fp) for fp in paths) - logger.debug("Schemas (%s):\n%s", artifact, msg) + logger.debug("Parquet schemas (%s):\n%s", artifact, msg) else: msg = f"Found no matching paths for {artifact!r} in {self.scale_factor_dir.as_posix()}" raise NotImplementedError(msg) return self +def to_polars(rel: Rel) -> pl.LazyFrame: + return rel.pl(lazy=True).cast(cast_map()) + + def _configure_logger( *, debug: bool, fmt: str = "%(asctime)s.%(msecs)03d [%(levelname)s] %(message)s", - datefmt: str = "%Y-%m-%d %H:%M:%S", + datefmt: str = "%H:%M:%S", ) -> None: logger.setLevel(logging.DEBUG if debug else logging.INFO) output = logging.StreamHandler() @@ -204,10 +252,10 @@ def start_section(self, heading: str | None) -> None: parser.add_argument( "-sf", "--scale-factor", - default=str(SCALE_FACTOR_DEFAULT), + default=SCALE_FACTOR_DEFAULT, metavar="", help=f"Scale the database by this factor (default: %(default)s)\n{TABLE_SCALE_FACTOR}", - type=float, + choices=SCALE_FACTORS, ) parser.add_argument( "--debug", action="store_true", help="Enable more detailed logging" diff --git a/tpch/tests/conftest.py b/tpch/tests/conftest.py index 7c34cc9ee4..9b04aed35e 100644 --- a/tpch/tests/conftest.py +++ b/tpch/tests/conftest.py @@ -7,12 +7,12 @@ import pytest from tpch.classes import Backend, Query -from tpch.constants import SCALE_FACTOR_DEFAULT +from tpch.constants import SCALE_FACTOR_DEFAULT, SCALE_FACTORS if TYPE_CHECKING: from collections.abc import Iterator - from tpch.typing_ import QueryID + from tpch.typing_ import QueryID, ScaleFactor # Table names used to construct paths dynamically TBL_LINEITEM = "lineitem" @@ -24,43 +24,6 @@ TBL_ORDERS = "orders" TBL_CUSTOMER = "customer" -SCALE_FACTORS_BLESSED = frozenset( - (1.0, 10.0, 30.0, 100.0, 300.0, 1_000.0, 3_000.0, 10_000.0, 30_000.0, 100_000.0) -) -"""`scale_factor` values that are listed on [TPC-H v3.0.1 (Page 79)]. - -Using any other value *can* lead to incorrect results. - -[TPC-H_v3.0.1 (Page 79)]: https://www.tpc.org/TPC_Documents_Current_Versions/pdf/TPC-H_v3.0.1.pdf -""" - -SCALE_FACTORS_QUITE_SAFE = frozenset( - ( - 0.014, - 0.02, - 0.029, - 0.04, - 0.052, - 0.06, - 0.072, - 0.081, - 0.091, - 0.1, - 0.13, - 0.23, - 0.25, - 0.275, - 0.29, - 0.3, - 0.43, - 0.51, - ) -) -"""scale_factor` values that are **lower** than [TPC-H v3.0.1 (Page 79)], but still work fine. - -[TPC-H_v3.0.1 (Page 79)]: https://www.tpc.org/TPC_Documents_Current_Versions/pdf/TPC-H_v3.0.1.pdf -""" - def is_xdist_worker(obj: pytest.FixtureRequest | pytest.Config, /) -> bool: # Adapted from https://github.com/pytest-dev/pytest-xdist/blob/8b60b1ef5d48974a1cb69bc1a9843564bdc06498/src/xdist/plugin.py#L337-L349 @@ -95,8 +58,8 @@ def pytest_addoption(parser: pytest.Parser) -> None: parser.addoption( "--scale-factor", action="store", - default=str(SCALE_FACTOR_DEFAULT), - type=float, + default=SCALE_FACTOR_DEFAULT, + choices=SCALE_FACTORS, help="TPC-H scale factor to use for tests (default: %(default)s)", ) @@ -136,7 +99,6 @@ def q(query_id: QueryID, *table_names: str) -> Query: def iter_queries() -> Iterator[Query]: - safe = SCALE_FACTORS_BLESSED | SCALE_FACTORS_QUITE_SAFE yield from ( q("q1", TBL_LINEITEM), q("q2", TBL_REGION, TBL_NATION, TBL_SUPPLIER, TBL_PART, TBL_PARTSUPP), @@ -173,41 +135,28 @@ def iter_queries() -> Iterator[Query]: TBL_SUPPLIER, ), q("q10", TBL_CUSTOMER, TBL_NATION, TBL_LINEITEM, TBL_ORDERS), - q("q11", TBL_NATION, TBL_PARTSUPP, TBL_SUPPLIER).with_skip( - lambda _, scale_factor: scale_factor not in safe, - reason="https://github.com/duckdb/duckdb/issues/17965", - ), + q("q11", TBL_NATION, TBL_PARTSUPP, TBL_SUPPLIER), q("q12", TBL_LINEITEM, TBL_ORDERS), q("q13", TBL_CUSTOMER, TBL_ORDERS), q("q14", TBL_LINEITEM, TBL_PART), q("q15", TBL_LINEITEM, TBL_SUPPLIER), q("q16", TBL_PART, TBL_PARTSUPP, TBL_SUPPLIER), - q("q17", TBL_LINEITEM, TBL_PART) - .with_xfail( - lambda _, scale_factor: (scale_factor < 0.014), - reason="Generated dataset is too small, leading to 0 rows after the first two filters in `query1`.", - ) - .with_skip( - lambda _, scale_factor: scale_factor not in safe, - reason="Non-deterministic fails for `duckdb`, `sqlframe`. All other always fail, except `pyarrow` which always passes 🤯.", - ), + q("q17", TBL_LINEITEM, TBL_PART), q("q18", TBL_CUSTOMER, TBL_LINEITEM, TBL_ORDERS), q("q19", TBL_LINEITEM, TBL_PART), q("q20", TBL_PART, TBL_PARTSUPP, TBL_NATION, TBL_LINEITEM, TBL_SUPPLIER), - q("q21", TBL_LINEITEM, TBL_NATION, TBL_ORDERS, TBL_SUPPLIER).with_skip( - lambda _, scale_factor: scale_factor not in safe, reason="Off-by-1 error" - ), + q("q21", TBL_LINEITEM, TBL_NATION, TBL_ORDERS, TBL_SUPPLIER), q("q22", TBL_CUSTOMER, TBL_ORDERS), ) @pytest.fixture(scope="session") -def scale_factor(request: pytest.FixtureRequest) -> float: +def scale_factor(request: pytest.FixtureRequest) -> ScaleFactor: """Get the scale factor from pytest options.""" - return float(request.config.getoption("--scale-factor")) + return request.config.getoption("--scale-factor") @pytest.fixture(params=iter_queries(), ids=repr) -def query(request: pytest.FixtureRequest, scale_factor: float) -> Query: +def query(request: pytest.FixtureRequest, scale_factor: ScaleFactor) -> Query: result: Query = request.param return result.with_scale_factor(scale_factor) diff --git a/tpch/tests/queries_test.py b/tpch/tests/queries_test.py index e698a5bd17..3b1ab7536a 100644 --- a/tpch/tests/queries_test.py +++ b/tpch/tests/queries_test.py @@ -3,13 +3,9 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - import pytest - from tpch.classes import Backend, Query -def test_execute_query( - query: Query, backend: Backend, request: pytest.FixtureRequest -) -> None: +def test_execute_query(query: Query, backend: Backend) -> None: """Helper function to run a TPCH query test.""" - query.execute(backend, request) + query.execute(backend) diff --git a/tpch/typing_.py b/tpch/typing_.py index b6c3010455..5e69c9ec56 100644 --- a/tpch/typing_.py +++ b/tpch/typing_.py @@ -2,16 +2,12 @@ from typing import TYPE_CHECKING, Any, Literal, Protocol -from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl - if TYPE_CHECKING: from typing_extensions import TypeAlias import narwhals as nw - from tpch.classes import Backend -KnownImpl: TypeAlias = Literal[_EagerAllowedImpl, _LazyAllowedImpl] TPCHBackend: TypeAlias = Literal[ "polars[lazy]", "pyarrow", "pandas[pyarrow]", "dask", "duckdb", "sqlframe" ] @@ -39,27 +35,30 @@ "q21", "q22", ] -XFailRaises: TypeAlias = type[BaseException] | tuple[type[BaseException], ...] -Artifact: TypeAlias = Literal["database", "answers"] +ScaleFactor: TypeAlias = Literal[ + "0.014", "0.052", "0.1", "0.25", "0.51", "1.0", "10.0", "30.0" +] +"""Values for `scale_factor` that are known to produce correct results. +These three are blessed by [TPC-H v3.0.1 (Page 79)]: -class QueryModule(Protocol): - def query( - self, *args: nw.LazyFrame[Any], **kwds: nw.LazyFrame[Any] - ) -> nw.LazyFrame[Any]: ... + "1.0", "10.0", "30.0" +These five are *not*, but represent a [benchmark runtime] between 13-72 seconds: -class Predicate(Protocol): - """Failure-state-context callback. + "0.014", "0.052", "0.1", "0.25", "0.51" - The returned value will be used in either: +Warning: + Running the higher values can **easily** crash when combined with [`pytest-xdist`]. + We are effectively running `scale_factor * 6` when all backends are selected. - pytest.mark.xfail(predicate(backend, scale_factor)) +[TPC-H v3.0.1 (Page 79)]: https://www.tpc.org/TPC_Documents_Current_Versions/pdf/TPC-H_v3.0.1.pdf +[benchmark runtime]: https://github.com/narwhals-dev/narwhals/pull/3421#discussion_r2743356336 +[`pytest-xdist`]: https://pytest-xdist.readthedocs.io/en/stable/ +""" - Or: +Artifact: TypeAlias = Literal["database", "answers"] - if predicate(backend, scale_factor): - pytest.skip() - """ - def __call__(self, backend: Backend, scale_factor: float, /) -> bool: ... +class QueryModule(Protocol): + def query(self, *args: nw.LazyFrame[Any]) -> nw.LazyFrame[Any]: ...