Skip to content

Commit 5cf8066

Browse files
chore(tpch): Create one folder for each scale_factor (#3427)
Co-authored-by: dangotbanned <125183946+dangotbanned@users.noreply.github.com>
1 parent dd21fa0 commit 5cf8066

File tree

5 files changed

+271
-240
lines changed

5 files changed

+271
-240
lines changed

tpch/classes.py

Lines changed: 57 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,23 @@
44
from importlib import import_module
55
from typing import TYPE_CHECKING, Any
66

7-
import polars as pl
8-
import pytest
9-
from polars.testing import assert_frame_equal as pl_assert_frame_equal
10-
117
import narwhals as nw
128
from narwhals.exceptions import NarwhalsError
13-
from tpch import constants
9+
from tpch.constants import (
10+
DATABASE_TABLE_NAMES,
11+
LOGGER_NAME,
12+
QUERIES_PACKAGE,
13+
QUERY_IDS,
14+
SCALE_FACTOR_DEFAULT,
15+
_scale_factor_dir,
16+
)
1417

1518
if TYPE_CHECKING:
1619
from collections.abc import Iterable
1720
from pathlib import Path
1821

22+
import polars as pl
23+
import pytest
1924
from typing_extensions import Self
2025

2126
from narwhals.typing import FileSource
@@ -51,38 +56,60 @@ def scan(self, source: FileSource) -> nw.LazyFrame[Any]:
5156

5257
class Query:
5358
id: QueryID
54-
paths: tuple[Path, ...]
59+
table_names: tuple[str, ...]
60+
scale_factor: float
5561
_into_xfails: tuple[tuple[Predicate, str, XFailRaises], ...]
5662
_into_skips: tuple[tuple[Predicate, str], ...]
5763

58-
def __init__(self, query_id: QueryID, paths: tuple[Path, ...]) -> None:
64+
def __init__(self, query_id: QueryID, table_names: tuple[str, ...]) -> None:
5965
self.id = query_id
60-
self.paths = paths
66+
self.table_names = table_names
6167
self._into_xfails = ()
6268
self._into_skips = ()
69+
self.scale_factor = SCALE_FACTOR_DEFAULT
6370

6471
def __repr__(self) -> str:
6572
return self.id
6673

67-
def execute(
68-
self, backend: Backend, scale_factor: float, request: pytest.FixtureRequest
69-
) -> None:
70-
self._apply_skips(backend, scale_factor)
71-
data = tuple(backend.scan(fp.as_posix()) for fp in self.paths)
74+
def inputs(self, backend: Backend) -> tuple[nw.LazyFrame[Any], ...]:
75+
"""Get the frame inputs for this query at the given scale factor."""
76+
sf_dir = _scale_factor_dir(self.scale_factor)
77+
return tuple(
78+
backend.scan((sf_dir / f"{name}.parquet").as_posix())
79+
for name in self.table_names
80+
)
81+
82+
def expected(self) -> pl.DataFrame:
83+
import polars as pl
84+
85+
sf_dir = _scale_factor_dir(self.scale_factor)
86+
return pl.read_parquet(sf_dir / f"result_{self}.parquet")
87+
88+
def execute(self, backend: Backend, request: pytest.FixtureRequest) -> None:
89+
from polars.testing import assert_frame_equal
90+
91+
self._apply_skips(backend)
92+
data = self.inputs(backend)
7293
query = self._import_module().query
94+
7395
try:
7496
result = query(*data).lazy().collect("polars").to_polars()
7597
except NarwhalsError as exc:
76-
msg = f"Query [{self}-{backend}] ({scale_factor=}) failed with the following error in Narwhals:\n{exc}"
98+
msg = f"Query [{self}-{backend}] ({self.scale_factor=}) failed with the following error in Narwhals:\n{exc}"
7799
raise RuntimeError(msg) from exc
78-
self._apply_xfails(backend, scale_factor, request)
79-
expected = pl.read_parquet(constants.DATA_DIR / f"result_{self}.parquet")
100+
101+
self._apply_xfails(backend, request)
102+
expected = self.expected()
80103
try:
81-
pl_assert_frame_equal(expected, result, check_dtypes=False)
104+
assert_frame_equal(expected, result, check_dtypes=False)
82105
except AssertionError as exc:
83-
msg = f"Query [{self}-{backend}] ({scale_factor=}) resulted in wrong answer:\n{exc}"
106+
msg = f"Query [{self}-{backend}] ({self.scale_factor=}) resulted in wrong answer:\n{exc}"
84107
raise AssertionError(msg) from exc
85108

109+
def with_scale_factor(self, scale_factor: float, /) -> Query:
110+
self.scale_factor = scale_factor
111+
return self
112+
86113
def with_skip(self, predicate: Predicate, reason: str) -> Query:
87114
self._into_skips = (*self._into_skips, (predicate, reason))
88115
return self
@@ -93,25 +120,27 @@ def with_xfail(
93120
self._into_xfails = (*self._into_xfails, (predicate, reason, raises))
94121
return self
95122

96-
def _apply_skips(self, backend: Backend, scale_factor: float) -> None:
123+
def _apply_skips(self, backend: Backend) -> None:
124+
import pytest
125+
97126
for predicate, reason in self._into_skips:
98-
if predicate(backend, scale_factor):
127+
if predicate(backend, self.scale_factor):
99128
pytest.skip(reason)
100129

101-
def _apply_xfails(
102-
self, backend: Backend, scale_factor: float, request: pytest.FixtureRequest
103-
) -> None:
130+
def _apply_xfails(self, backend: Backend, request: pytest.FixtureRequest) -> None:
131+
import pytest
132+
104133
for predicate, reason, raises in self._into_xfails:
105-
condition = predicate(backend, scale_factor)
134+
condition = predicate(backend, self.scale_factor)
106135
mark = pytest.mark.xfail(condition, reason=reason, raises=raises)
107136
request.applymarker(mark)
108137

109138
def _import_module(self) -> QueryModule:
110-
result: Any = import_module(f"{constants.QUERIES_PACKAGE}.{self}")
139+
result: Any = import_module(f"{QUERIES_PACKAGE}.{self}")
111140
return result
112141

113142

114-
logger = logging.getLogger(constants.LOGGER_NAME)
143+
logger = logging.getLogger(LOGGER_NAME)
115144

116145

117146
class TableLogger:
@@ -125,11 +154,11 @@ def __init__(self, file_names: Iterable[str]) -> None:
125154

126155
@staticmethod
127156
def answers() -> TableLogger:
128-
return TableLogger(f"result_{qid}.parquet" for qid in constants.QUERY_IDS)
157+
return TableLogger(f"result_{qid}.parquet" for qid in QUERY_IDS)
129158

130159
@staticmethod
131160
def database() -> TableLogger:
132-
return TableLogger(f"{t}.parquet" for t in constants.DATABASE_TABLE_NAMES)
161+
return TableLogger(f"{t}.parquet" for t in DATABASE_TABLE_NAMES)
133162

134163
def __enter__(self) -> Self:
135164
# header

tpch/constants.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from functools import cache
34
from pathlib import Path
45
from typing import TYPE_CHECKING, get_args
56

@@ -11,12 +12,17 @@
1112
REPO_ROOT = Path(__file__).parent.parent
1213
TPCH_DIR = REPO_ROOT / "tpch"
1314
DATA_DIR = TPCH_DIR / "data"
14-
METADATA_PATH = DATA_DIR / "metadata.csv"
15-
"""For reflection in tests.
1615

17-
E.g. if we *know* the query is not valid for a given `scale_factor`,
18-
then we can determine if a failure is expected.
19-
"""
16+
17+
@cache
18+
def _scale_factor_dir(scale_factor: float) -> Path:
19+
"""Get the data directory for a specific scale factor."""
20+
sf_dir = DATA_DIR / f"sf{scale_factor}"
21+
sf_dir.mkdir(parents=True, exist_ok=True)
22+
return sf_dir
23+
24+
25+
SCALE_FACTOR_DEFAULT = 0.1
2026
DATABASE_TABLE_NAMES = (
2127
"lineitem",
2228
"customer",

0 commit comments

Comments
 (0)