44from importlib import import_module
55from typing import TYPE_CHECKING , Any
66
7- import polars as pl
8- import pytest
9- from polars .testing import assert_frame_equal as pl_assert_frame_equal
10-
117import narwhals as nw
128from narwhals .exceptions import NarwhalsError
13- from tpch import constants
9+ from tpch .constants import (
10+ DATABASE_TABLE_NAMES ,
11+ LOGGER_NAME ,
12+ QUERIES_PACKAGE ,
13+ QUERY_IDS ,
14+ SCALE_FACTOR_DEFAULT ,
15+ _scale_factor_dir ,
16+ )
1417
1518if TYPE_CHECKING :
1619 from collections .abc import Iterable
1720 from pathlib import Path
1821
22+ import polars as pl
23+ import pytest
1924 from typing_extensions import Self
2025
2126 from narwhals .typing import FileSource
@@ -51,38 +56,60 @@ def scan(self, source: FileSource) -> nw.LazyFrame[Any]:
5156
5257class Query :
5358 id : QueryID
54- paths : tuple [Path , ...]
59+ table_names : tuple [str , ...]
60+ scale_factor : float
5561 _into_xfails : tuple [tuple [Predicate , str , XFailRaises ], ...]
5662 _into_skips : tuple [tuple [Predicate , str ], ...]
5763
58- def __init__ (self , query_id : QueryID , paths : tuple [Path , ...]) -> None :
64+ def __init__ (self , query_id : QueryID , table_names : tuple [str , ...]) -> None :
5965 self .id = query_id
60- self .paths = paths
66+ self .table_names = table_names
6167 self ._into_xfails = ()
6268 self ._into_skips = ()
69+ self .scale_factor = SCALE_FACTOR_DEFAULT
6370
6471 def __repr__ (self ) -> str :
6572 return self .id
6673
67- def execute (
68- self , backend : Backend , scale_factor : float , request : pytest .FixtureRequest
69- ) -> None :
70- self ._apply_skips (backend , scale_factor )
71- data = tuple (backend .scan (fp .as_posix ()) for fp in self .paths )
74+ def inputs (self , backend : Backend ) -> tuple [nw .LazyFrame [Any ], ...]:
75+ """Get the frame inputs for this query at the given scale factor."""
76+ sf_dir = _scale_factor_dir (self .scale_factor )
77+ return tuple (
78+ backend .scan ((sf_dir / f"{ name } .parquet" ).as_posix ())
79+ for name in self .table_names
80+ )
81+
82+ def expected (self ) -> pl .DataFrame :
83+ import polars as pl
84+
85+ sf_dir = _scale_factor_dir (self .scale_factor )
86+ return pl .read_parquet (sf_dir / f"result_{ self } .parquet" )
87+
88+ def execute (self , backend : Backend , request : pytest .FixtureRequest ) -> None :
89+ from polars .testing import assert_frame_equal
90+
91+ self ._apply_skips (backend )
92+ data = self .inputs (backend )
7293 query = self ._import_module ().query
94+
7395 try :
7496 result = query (* data ).lazy ().collect ("polars" ).to_polars ()
7597 except NarwhalsError as exc :
76- msg = f"Query [{ self } -{ backend } ] ({ scale_factor = } ) failed with the following error in Narwhals:\n { exc } "
98+ msg = f"Query [{ self } -{ backend } ] ({ self . scale_factor = } ) failed with the following error in Narwhals:\n { exc } "
7799 raise RuntimeError (msg ) from exc
78- self ._apply_xfails (backend , scale_factor , request )
79- expected = pl .read_parquet (constants .DATA_DIR / f"result_{ self } .parquet" )
100+
101+ self ._apply_xfails (backend , request )
102+ expected = self .expected ()
80103 try :
81- pl_assert_frame_equal (expected , result , check_dtypes = False )
104+ assert_frame_equal (expected , result , check_dtypes = False )
82105 except AssertionError as exc :
83- msg = f"Query [{ self } -{ backend } ] ({ scale_factor = } ) resulted in wrong answer:\n { exc } "
106+ msg = f"Query [{ self } -{ backend } ] ({ self . scale_factor = } ) resulted in wrong answer:\n { exc } "
84107 raise AssertionError (msg ) from exc
85108
109+ def with_scale_factor (self , scale_factor : float , / ) -> Query :
110+ self .scale_factor = scale_factor
111+ return self
112+
86113 def with_skip (self , predicate : Predicate , reason : str ) -> Query :
87114 self ._into_skips = (* self ._into_skips , (predicate , reason ))
88115 return self
@@ -93,25 +120,27 @@ def with_xfail(
93120 self ._into_xfails = (* self ._into_xfails , (predicate , reason , raises ))
94121 return self
95122
96- def _apply_skips (self , backend : Backend , scale_factor : float ) -> None :
123+ def _apply_skips (self , backend : Backend ) -> None :
124+ import pytest
125+
97126 for predicate , reason in self ._into_skips :
98- if predicate (backend , scale_factor ):
127+ if predicate (backend , self . scale_factor ):
99128 pytest .skip (reason )
100129
101- def _apply_xfails (
102- self , backend : Backend , scale_factor : float , request : pytest . FixtureRequest
103- ) -> None :
130+ def _apply_xfails (self , backend : Backend , request : pytest . FixtureRequest ) -> None :
131+ import pytest
132+
104133 for predicate , reason , raises in self ._into_xfails :
105- condition = predicate (backend , scale_factor )
134+ condition = predicate (backend , self . scale_factor )
106135 mark = pytest .mark .xfail (condition , reason = reason , raises = raises )
107136 request .applymarker (mark )
108137
109138 def _import_module (self ) -> QueryModule :
110- result : Any = import_module (f"{ constants . QUERIES_PACKAGE } .{ self } " )
139+ result : Any = import_module (f"{ QUERIES_PACKAGE } .{ self } " )
111140 return result
112141
113142
114- logger = logging .getLogger (constants . LOGGER_NAME )
143+ logger = logging .getLogger (LOGGER_NAME )
115144
116145
117146class TableLogger :
@@ -125,11 +154,11 @@ def __init__(self, file_names: Iterable[str]) -> None:
125154
126155 @staticmethod
127156 def answers () -> TableLogger :
128- return TableLogger (f"result_{ qid } .parquet" for qid in constants . QUERY_IDS )
157+ return TableLogger (f"result_{ qid } .parquet" for qid in QUERY_IDS )
129158
130159 @staticmethod
131160 def database () -> TableLogger :
132- return TableLogger (f"{ t } .parquet" for t in constants . DATABASE_TABLE_NAMES )
161+ return TableLogger (f"{ t } .parquet" for t in DATABASE_TABLE_NAMES )
133162
134163 def __enter__ (self ) -> Self :
135164 # header
0 commit comments