Skip to content
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
78d9634
WIP: sf folders
FBruzzesi Jan 28, 2026
8eb968e
merge head and solve conflicts
FBruzzesi Jan 28, 2026
7de63ea
solve conflicts
FBruzzesi Jan 29, 2026
cb889c7
minor adjustments
FBruzzesi Jan 29, 2026
84a87db
rm metadata
FBruzzesi Jan 29, 2026
b7fc4c5
Update `inputs` docstring
FBruzzesi Jan 29, 2026
e17f5e2
rm duplicated scale_factor_exists, add logging in pytest_configure if …
FBruzzesi Jan 30, 2026
1e15604
rm duplicated scale_factor_exists, add logging in pytest_configure if …
FBruzzesi Jan 30, 2026
f60a79b
refactor: Add `constants.SCALE_FACTOR_DEFAULT`
dangotbanned Jan 30, 2026
fa6a553
refactor: Inline imports again
dangotbanned Jan 30, 2026
b3adca8
fix: Skip data generation once we're in a session
dangotbanned Jan 30, 2026
92ef45a
refactor: Simp
dangotbanned Jan 30, 2026
3cfc1ee
refactor: Use default from `addoption`
dangotbanned Jan 30, 2026
c4ad291
refactor: Use `scale_factor` fixture inside `query` fixture
dangotbanned Jan 30, 2026
615d7fa
refactor: Center `generate_data` around a class (`TPCHGen`)
dangotbanned Jan 30, 2026
19d412c
docs: Shrink doc
dangotbanned Jan 30, 2026
6d289c2
docs: Add `ScaleFactor` alias
dangotbanned Jan 30, 2026
c347216
chore: Restrict allowed `scale_factor` values
dangotbanned Jan 30, 2026
40a0dea
Merge branch 'tpch/refactor-cli' into tpch/refactor-cli-choices
dangotbanned Jan 31, 2026
53147b9
fix: Remove unused `request`
dangotbanned Jan 31, 2026
6332fd4
docs: Show all options in `--help`, redo style
dangotbanned Jan 31, 2026
71e5998
fix: Don't allow uninitialized field repr
dangotbanned Jan 31, 2026
bc0afb7
chore: Add debug step for post-`generate_database`
dangotbanned Jan 31, 2026
228a1d5
Merge remote-tracking branch 'upstream/main' into tpch/refactor-cli-c…
dangotbanned Jan 31, 2026
93fbf49
perf: Avoid materializing, fix memory leaks
dangotbanned Jan 31, 2026
f72f7b8
remove date from logs
dangotbanned Jan 31, 2026
8b47e94
fix: Support `scale_factor=30.0` without memory issues
dangotbanned Jan 31, 2026
48cdfdd
chore: Clean up when you're done
dangotbanned Feb 1, 2026
1656d8d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 1, 2026
6232ec1
Merge remote-tracking branch 'upstream/tpch/refactor-cli' into tpch/r…
dangotbanned Feb 1, 2026
e261f50
ooops, missed a rename
dangotbanned Feb 1, 2026
118a396
Update tpch/generate_data.py
dangotbanned Feb 1, 2026
78d0d2a
Merge branch 'tpch/refactor-cli-choices' of https://github.com/narwha…
dangotbanned Feb 1, 2026
3c23ce8
drive-by tidy
dangotbanned Feb 1, 2026
44ef295
refactor: Drop `TPCHGen._batches`
dangotbanned Feb 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 7 additions & 46 deletions tpch/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from importlib import import_module
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

import narwhals as nw
from narwhals.exceptions import NarwhalsError
Expand All @@ -20,23 +20,16 @@
from pathlib import Path

import polars as pl
import pytest
from typing_extensions import Self

from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl
from narwhals.typing import FileSource
from tpch.typing_ import (
KnownImpl,
Predicate,
QueryID,
QueryModule,
TPCHBackend,
XFailRaises,
)
from tpch.typing_ import QueryID, QueryModule, ScaleFactor, TPCHBackend


class Backend:
name: TPCHBackend
implementation: KnownImpl
implementation: Literal[_EagerAllowedImpl, _LazyAllowedImpl]
kwds: dict[str, Any]

def __init__(self, name: TPCHBackend, /, **kwds: Any) -> None:
Expand All @@ -57,15 +50,11 @@ def scan(self, source: FileSource) -> nw.LazyFrame[Any]:
class Query:
id: QueryID
table_names: tuple[str, ...]
scale_factor: float
_into_xfails: tuple[tuple[Predicate, str, XFailRaises], ...]
_into_skips: tuple[tuple[Predicate, str], ...]
scale_factor: ScaleFactor

def __init__(self, query_id: QueryID, table_names: tuple[str, ...]) -> None:
self.id = query_id
self.table_names = table_names
self._into_xfails = ()
self._into_skips = ()
self.scale_factor = SCALE_FACTOR_DEFAULT

def __repr__(self) -> str:
Expand All @@ -85,10 +74,9 @@ def expected(self) -> pl.DataFrame:
sf_dir = _scale_factor_dir(self.scale_factor)
return pl.read_parquet(sf_dir / f"result_{self}.parquet")

def execute(self, backend: Backend, request: pytest.FixtureRequest) -> None:
def execute(self, backend: Backend) -> None:
from polars.testing import assert_frame_equal

self._apply_skips(backend)
data = self.inputs(backend)
query = self._import_module().query

Expand All @@ -97,44 +85,17 @@ def execute(self, backend: Backend, request: pytest.FixtureRequest) -> None:
except NarwhalsError as exc:
msg = f"Query [{self}-{backend}] ({self.scale_factor=}) failed with the following error in Narwhals:\n{exc}"
raise RuntimeError(msg) from exc

self._apply_xfails(backend, request)
expected = self.expected()
try:
assert_frame_equal(expected, result, check_dtypes=False)
except AssertionError as exc:
msg = f"Query [{self}-{backend}] ({self.scale_factor=}) resulted in wrong answer:\n{exc}"
raise AssertionError(msg) from exc

def with_scale_factor(self, scale_factor: float, /) -> Query:
def with_scale_factor(self, scale_factor: ScaleFactor, /) -> Query:
self.scale_factor = scale_factor
return self

def with_skip(self, predicate: Predicate, reason: str) -> Query:
self._into_skips = (*self._into_skips, (predicate, reason))
return self

def with_xfail(
self, predicate: Predicate, reason: str, *, raises: XFailRaises = AssertionError
) -> Query:
self._into_xfails = (*self._into_xfails, (predicate, reason, raises))
return self

def _apply_skips(self, backend: Backend) -> None:
import pytest

for predicate, reason in self._into_skips:
if predicate(backend, self.scale_factor):
pytest.skip(reason)

def _apply_xfails(self, backend: Backend, request: pytest.FixtureRequest) -> None:
import pytest

for predicate, reason, raises in self._into_xfails:
condition = predicate(backend, self.scale_factor)
mark = pytest.mark.xfail(condition, reason=reason, raises=raises)
request.applymarker(mark)

def _import_module(self) -> QueryModule:
    """Import and return the module that implements this query.

    The dotted module path is `QUERIES_PACKAGE` followed by this query's
    string form.
    """
    dotted_path = f"{QUERIES_PACKAGE}.{self}"
    # Bound as `Any` so the dynamically imported module satisfies the
    # `QueryModule` return annotation.
    module: Any = import_module(dotted_path)
    return module
Expand Down
8 changes: 5 additions & 3 deletions tpch/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,27 @@
from pathlib import Path
from typing import TYPE_CHECKING, get_args

from tpch.typing_ import Artifact, QueryID
from tpch.typing_ import Artifact, QueryID, ScaleFactor

if TYPE_CHECKING:
from collections.abc import Mapping

REPO_ROOT = Path(__file__).parent.parent
TPCH_DIR = REPO_ROOT / "tpch"
DATA_DIR = TPCH_DIR / "data"
DB_PATH = DATA_DIR / "narwhals.duckdb"


@cache
def _scale_factor_dir(scale_factor: float) -> Path:
def _scale_factor_dir(scale_factor: ScaleFactor) -> Path:
"""Get the data directory for a specific scale factor."""
sf_dir = DATA_DIR / f"sf{scale_factor}"
sf_dir.mkdir(parents=True, exist_ok=True)
return sf_dir


SCALE_FACTOR_DEFAULT = 0.1
SCALE_FACTORS: tuple[ScaleFactor, ...] = get_args(ScaleFactor)
SCALE_FACTOR_DEFAULT: ScaleFactor = "0.1"
DATABASE_TABLE_NAMES = (
"lineitem",
"customer",
Expand Down
117 changes: 87 additions & 30 deletions tpch/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,49 +11,68 @@
import os
import sys
from functools import cache
from typing import TYPE_CHECKING, Any
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

from tpch.classes import TableLogger
from tpch.constants import (
DATABASE_TABLE_NAMES,
DB_PATH,
GLOBS,
LOGGER_NAME,
QUERY_IDS,
SCALE_FACTOR_DEFAULT,
SCALE_FACTORS,
_scale_factor_dir,
)

if TYPE_CHECKING:
from collections.abc import Callable, Iterator, Mapping
from pathlib import Path

import polars as pl
import pytest
from duckdb import DuckDBPyConnection as Con, DuckDBPyRelation as Rel
from typing_extensions import LiteralString

from tpch.typing_ import Artifact, QueryID
from tpch.typing_ import Artifact, QueryID, ScaleFactor


logger = logging.getLogger(LOGGER_NAME)


# `mem_usage_scale = 2.705`
# `pl.Config(tbl_hide_column_data_types=True, tbl_hide_dataframe_shape=True)`
# https://duckdb.org/docs/stable/core_extensions/tpch#resource-usage-of-the-data-generator
TABLE_SCALE_FACTOR = """
┌──────────────┬───────────────┐
│ Scale factor ┆ Database (MB) │
╞══════════════╪═══════════════╡
│ 0.1 ┆ 25 │
│ 1.0 ┆ 250 │
│ 3.0 ┆ 754 │
│ 100.0 ┆ 26624 │
└──────────────┴───────────────┘
┌───────┬────────────┬─────────────┐
│ sf ┆ Disk ┆ Memory (db) │
╞═══════╪════════════╪═════════════╡
│ 0.014 ┆ 3.25 mb ┆ 8.79 mb │
│ 0.052 ┆ 12.01 mb ┆ 32.49 mb │
│ 0.1 ┆ 23.15 mb ┆ 62.62 mb │
│ 0.25 ┆ 58.90 mb ┆ 159.32 mb │
│ 0.51 ┆ 124.40 mb ┆ 336.50 mb │
│ 1.0 ┆ 247.66 mb ┆ 669.92 mb │
│ 10.0 ┆ 2.59 gb ┆ 7.00 gb │
│ 30.0 ┆ 7.76 gb ┆ 21.00 gb │
└───────┴────────────┴─────────────┘
"""


# NOTE: Store queries here, add parameter names if needed
SQL_DBGEN = "CALL dbgen(sf={0})"
SQL_DBGEN = "CALL dbgen(sf=$sf)"
SQL_DBGEN_STEPPED = "CALL dbgen(sf=$sf, children=$children, step=$step)"
SQL_TPCH_ANSWER = "PRAGMA tpch({0})"
SQL_FROM = "FROM {0}"
SQL_SHOW_DB = """
SELECT
"table": name,
"schema": MAP(column_names, column_types)
FROM
(SHOW ALL TABLES)
"""

FIX_ANSWERS: Mapping[QueryID, Callable[[pl.DataFrame], pl.DataFrame]] = {
FIX_ANSWERS: Mapping[QueryID, Callable[[pl.LazyFrame], pl.LazyFrame]] = {
"q18": lambda df: df.rename({"sum(l_quantity)": "sum"}).cast({"sum": int}),
"q22": lambda df: df.cast({"cntrycode": int}),
}
Expand Down Expand Up @@ -82,10 +101,10 @@ def cast_map() -> dict[Any, Any]:

@dataclasses.dataclass(**({"kw_only": True} if sys.version_info >= (3, 10) else {}))
class TPCHGen:
scale_factor: float
scale_factor: ScaleFactor
refresh: bool = False
debug: bool = False
_con: Con = dataclasses.field(init=False)
_con: Con = dataclasses.field(init=False, repr=False)

@staticmethod
def from_pytest(config: pytest.Config, /) -> TPCHGen:
Expand All @@ -99,6 +118,16 @@ def from_argparse(parser: argparse.ArgumentParser, /) -> TPCHGen:
def scale_factor_dir(self) -> Path:
return _scale_factor_dir(self.scale_factor)

@property
def _batches(self) -> int:
if self.scale_factor in {"10.0", "30.0"}:
return 12 if self.scale_factor == "10.0" else 4
return 0

@property
def _database(self) -> Path | Literal[":memory:"]:
return DB_PATH if self.scale_factor == "30.0" else ":memory:"

def glob(self, artifact: Artifact, /) -> Iterator[Path]:
    """Iterate over paths in this scale factor's directory for `artifact`.

    The glob pattern is looked up in `GLOBS` by `artifact`.
    """
    directory = self.scale_factor_dir
    pattern = GLOBS[artifact]
    return directory.glob(pattern)

Expand All @@ -115,40 +144,64 @@ def run(self) -> None:
self.show_schemas("database").show_schemas("answers")
logger.info("To regenerate this scale_factor, use `--refresh`")
return
self.connect().load_extension().generate_database().write_database().write_answers()
self.connect().load_extension().generate_database().write_database().write_answers().disconnect()
n_bytes = sum(e.stat().st_size for e in os.scandir(self.scale_factor_dir))
total = TableLogger.format_size(n_bytes)
logger.info("Finished with total file size: %s", total.strip())

def connect(self) -> TPCHGen:
import duckdb

logger.info("Connecting to in-memory DuckDB database")
self._con = duckdb.connect(database=":memory:")
database = self._database
name = database.as_posix() if isinstance(database, Path) else database
logger.info("Connecting to DuckDB database: %s", name)
self._con = duckdb.connect(database)
return self

# TODO @dangotbanned: Change to `LiteralString` after restricting `--scale-factor`
def sql(self, query: str) -> Rel:
return self._con.sql(query)
def disconnect(self) -> TPCHGen:
    """Close the DuckDB connection and delete the database file, if any.

    Returns `self` so calls can be chained.
    """
    self._con.close()
    database = self._database
    if not isinstance(database, Path):
        # In-memory database: there is no file to remove.
        return self
    logger.info("Dropping DuckDB database: %s", database.as_posix())
    database.unlink(missing_ok=True)
    return self

def sql(self, query: LiteralString, **params: LiteralString | int) -> Rel:
return self._con.sql(query, params=params or None)

def load_extension(self) -> TPCHGen:
    """Install, then load, DuckDB's `tpch` extension on the connection.

    Returns `self` so calls can be chained.
    """
    extension = "tpch"
    logger.info("Installing DuckDB TPC-H Extension")
    connection = self._con
    connection.install_extension(extension)
    connection.load_extension(extension)
    return self

def _generate_database_batched(self, batches: int) -> TPCHGen:
    """Generate the database with one `dbgen` call per batch.

    Runs `SQL_DBGEN_STEPPED` once for each step index in
    `range(batches)`, passing `batches` as `children`, and logs progress
    after every step. Returns `self` so calls can be chained.
    """
    logger.info("Whelp, this may take a while...")
    logger.info("Generating in %s batches", batches)
    # `step` is the 0-based index sent to DuckDB; `completed` is the
    # 1-based count shown in the progress message.
    for completed, step in enumerate(range(batches), start=1):
        self.sql(
            SQL_DBGEN_STEPPED,
            sf=self.scale_factor,
            children=batches,
            step=step,
        )
        logger.info("Generated (%s/%s)", completed, batches)
    return self

def generate_database(self) -> TPCHGen:
logger.info("Generating data for scale_factor=%s", self.scale_factor)
self.sql(SQL_DBGEN.format(self.scale_factor))
if batches := self._batches:
self._generate_database_batched(batches)
else:
self.sql(SQL_DBGEN, sf=self.scale_factor)
logger.info("Finished generating data.")
if logger.isEnabledFor(logging.DEBUG):
msg = str(self.sql(SQL_SHOW_DB))[:-1]
logger.debug("DuckDB schemas (database):\n%s", msg)
return self

def write_database(self) -> TPCHGen:
logger.info("Writing data to: %s", self.scale_factor_dir.as_posix())
with TableLogger.database() as tbl_logger:
for t in DATABASE_TABLE_NAMES:
path = self.scale_factor_dir / f"{t}.parquet"
self.sql(SQL_FROM.format(t)).pl().cast(cast_map()).write_parquet(path)
to_polars(self.sql(SQL_FROM.format(t))).sink_parquet(path)
tbl_logger.log_row(path)
return self.show_schemas("database")

Expand All @@ -157,30 +210,34 @@ def write_answers(self) -> TPCHGen:
with TableLogger.answers() as tbl_logger:
for query_id in QUERY_IDS:
query = SQL_TPCH_ANSWER.format(query_id.removeprefix("q"))
df = self.sql(query).pl().cast(cast_map())
lf = to_polars(self.sql(query))
if fix := FIX_ANSWERS.get(query_id):
df = fix(df)
lf = fix(lf)
path = self.scale_factor_dir / f"result_{query_id}.parquet"
df.write_parquet(path)
lf.sink_parquet(path)
tbl_logger.log_row(path)
return self.show_schemas("answers")

def show_schemas(self, artifact: Artifact, /) -> TPCHGen:
if logger.isEnabledFor(logging.DEBUG):
if paths := sorted(self.glob(artifact)):
msg = "\n".join(read_fmt_schema(fp) for fp in paths)
logger.debug("Schemas (%s):\n%s", artifact, msg)
logger.debug("Parquet schemas (%s):\n%s", artifact, msg)
else:
msg = f"Found no matching paths for {artifact!r} in {self.scale_factor_dir.as_posix()}"
raise NotImplementedError(msg)
return self


def to_polars(rel: Rel) -> pl.LazyFrame:
    """Convert a DuckDB relation to a lazy polars frame.

    The frame is cast with the mapping returned by `cast_map()`.
    """
    lazy_frame = rel.pl(lazy=True)
    dtype_overrides = cast_map()
    return lazy_frame.cast(dtype_overrides)


def _configure_logger(
*,
debug: bool,
fmt: str = "%(asctime)s.%(msecs)03d [%(levelname)s] %(message)s",
datefmt: str = "%Y-%m-%d %H:%M:%S",
datefmt: str = "%H:%M:%S",
) -> None:
logger.setLevel(logging.DEBUG if debug else logging.INFO)
output = logging.StreamHandler()
Expand All @@ -204,10 +261,10 @@ def start_section(self, heading: str | None) -> None:
parser.add_argument(
"-sf",
"--scale-factor",
default=str(SCALE_FACTOR_DEFAULT),
default=SCALE_FACTOR_DEFAULT,
metavar="",
help=f"Scale the database by this factor (default: %(default)s)\n{TABLE_SCALE_FACTOR}",
type=float,
choices=SCALE_FACTORS,
)
parser.add_argument(
"--debug", action="store_true", help="Enable more detailed logging"
Expand Down
Loading
Loading