Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 57 additions & 28 deletions tpch/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,23 @@
from importlib import import_module
from typing import TYPE_CHECKING, Any

import polars as pl
import pytest
from polars.testing import assert_frame_equal as pl_assert_frame_equal

import narwhals as nw
from narwhals.exceptions import NarwhalsError
from tpch import constants
from tpch.constants import (
DATABASE_TABLE_NAMES,
LOGGER_NAME,
QUERIES_PACKAGE,
QUERY_IDS,
SCALE_FACTOR_DEFAULT,
_scale_factor_dir,
)
Comment on lines +9 to +16
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this an intentional change, or lost in the conflicts?

In (134875f) I switched to constants so that the imports didn't take up so many lines 😢

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see! Sorry — feel free to revert it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll keep it for now — this next branch has a much larger negative diff 😄

tpch/refactor-cli-sf-folders...tpch/refactor-cli-choices


if TYPE_CHECKING:
from collections.abc import Iterable
from pathlib import Path

import polars as pl
import pytest
from typing_extensions import Self

from narwhals.typing import FileSource
Expand Down Expand Up @@ -51,38 +56,60 @@ def scan(self, source: FileSource) -> nw.LazyFrame[Any]:

class Query:
id: QueryID
paths: tuple[Path, ...]
table_names: tuple[str, ...]
scale_factor: float
_into_xfails: tuple[tuple[Predicate, str, XFailRaises], ...]
_into_skips: tuple[tuple[Predicate, str], ...]

def __init__(self, query_id: QueryID, paths: tuple[Path, ...]) -> None:
def __init__(self, query_id: QueryID, table_names: tuple[str, ...]) -> None:
self.id = query_id
self.paths = paths
self.table_names = table_names
self._into_xfails = ()
self._into_skips = ()
self.scale_factor = SCALE_FACTOR_DEFAULT

def __repr__(self) -> str:
return self.id

def execute(
self, backend: Backend, scale_factor: float, request: pytest.FixtureRequest
) -> None:
self._apply_skips(backend, scale_factor)
data = tuple(backend.scan(fp.as_posix()) for fp in self.paths)
def inputs(self, backend: Backend) -> tuple[nw.LazyFrame[Any], ...]:
    """Scan every table this query reads, using `backend`.

    Each name in `self.table_names` is resolved to ``<name>.parquet`` inside
    the directory returned by `_scale_factor_dir` for `self.scale_factor`,
    and the POSIX form of that path is handed to `backend.scan`.

    Returns:
        One frame per table name, in the same order as `self.table_names`.
    """
    root = _scale_factor_dir(self.scale_factor)
    sources = [(root / f"{name}.parquet").as_posix() for name in self.table_names]
    return tuple(backend.scan(source) for source in sources)

def expected(self) -> pl.DataFrame:
    """Read the stored answer for this query at `self.scale_factor`.

    The answer is read from ``result_<query>.parquet`` inside the directory
    returned by `_scale_factor_dir`.
    """
    # Imported here rather than at module level, as in the original.
    import polars as pl

    answer_path = _scale_factor_dir(self.scale_factor) / f"result_{self}.parquet"
    return pl.read_parquet(answer_path)

def execute(self, backend: Backend, request: pytest.FixtureRequest) -> None:
from polars.testing import assert_frame_equal

self._apply_skips(backend)
data = self.inputs(backend)
query = self._import_module().query

try:
result = query(*data).lazy().collect("polars").to_polars()
except NarwhalsError as exc:
msg = f"Query [{self}-{backend}] ({scale_factor=}) failed with the following error in Narwhals:\n{exc}"
msg = f"Query [{self}-{backend}] ({self.scale_factor=}) failed with the following error in Narwhals:\n{exc}"
raise RuntimeError(msg) from exc
self._apply_xfails(backend, scale_factor, request)
expected = pl.read_parquet(constants.DATA_DIR / f"result_{self}.parquet")

self._apply_xfails(backend, request)
expected = self.expected()
try:
pl_assert_frame_equal(expected, result, check_dtypes=False)
assert_frame_equal(expected, result, check_dtypes=False)
except AssertionError as exc:
msg = f"Query [{self}-{backend}] ({scale_factor=}) resulted in wrong answer:\n{exc}"
msg = f"Query [{self}-{backend}] ({self.scale_factor=}) resulted in wrong answer:\n{exc}"
raise AssertionError(msg) from exc

def with_scale_factor(self, scale_factor: float, /) -> Query:
self.scale_factor = scale_factor
return self

def with_skip(self, predicate: Predicate, reason: str) -> Query:
self._into_skips = (*self._into_skips, (predicate, reason))
return self
Expand All @@ -93,25 +120,27 @@ def with_xfail(
self._into_xfails = (*self._into_xfails, (predicate, reason, raises))
return self

def _apply_skips(self, backend: Backend, scale_factor: float) -> None:
def _apply_skips(self, backend: Backend) -> None:
import pytest

for predicate, reason in self._into_skips:
if predicate(backend, scale_factor):
if predicate(backend, self.scale_factor):
pytest.skip(reason)

def _apply_xfails(
self, backend: Backend, scale_factor: float, request: pytest.FixtureRequest
) -> None:
def _apply_xfails(self, backend: Backend, request: pytest.FixtureRequest) -> None:
import pytest

for predicate, reason, raises in self._into_xfails:
condition = predicate(backend, scale_factor)
condition = predicate(backend, self.scale_factor)
mark = pytest.mark.xfail(condition, reason=reason, raises=raises)
request.applymarker(mark)

def _import_module(self) -> QueryModule:
result: Any = import_module(f"{constants.QUERIES_PACKAGE}.{self}")
result: Any = import_module(f"{QUERIES_PACKAGE}.{self}")
return result


logger = logging.getLogger(constants.LOGGER_NAME)
logger = logging.getLogger(LOGGER_NAME)


class TableLogger:
Expand All @@ -125,11 +154,11 @@ def __init__(self, file_names: Iterable[str]) -> None:

@staticmethod
def answers() -> TableLogger:
return TableLogger(f"result_{qid}.parquet" for qid in constants.QUERY_IDS)
return TableLogger(f"result_{qid}.parquet" for qid in QUERY_IDS)

@staticmethod
def database() -> TableLogger:
return TableLogger(f"{t}.parquet" for t in constants.DATABASE_TABLE_NAMES)
return TableLogger(f"{t}.parquet" for t in DATABASE_TABLE_NAMES)

def __enter__(self) -> Self:
# header
Expand Down
16 changes: 11 additions & 5 deletions tpch/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import cache
from pathlib import Path
from typing import TYPE_CHECKING, get_args

Expand All @@ -11,12 +12,17 @@
REPO_ROOT = Path(__file__).parent.parent
TPCH_DIR = REPO_ROOT / "tpch"
DATA_DIR = TPCH_DIR / "data"
METADATA_PATH = DATA_DIR / "metadata.csv"
"""For reflection in tests.

E.g. if we *know* the query is not valid for a given `scale_factor`,
then we can determine if a failure is expected.
"""

@cache
def _scale_factor_dir(scale_factor: float) -> Path:
    """Return the data directory for `scale_factor`, creating it if missing.

    The directory is ``DATA_DIR / "sf<scale_factor>"``, where the suffix is
    the f-string rendering of the argument.

    NOTE: results are memoized by `functools.cache`, so the `mkdir` call only
    runs on the first call for each distinct argument.
    """
    target = DATA_DIR.joinpath(f"sf{scale_factor}")
    target.mkdir(parents=True, exist_ok=True)
    return target


SCALE_FACTOR_DEFAULT = 0.1
DATABASE_TABLE_NAMES = (
"lineitem",
"customer",
Expand Down
Loading
Loading