Commit de7347e

Merge branch 'main' into simp-pandas-group-by
2 parents: c8cbd78 + 4e43f54

File tree (4 files changed, +24 -17 lines):

- narwhals/_duckdb/utils.py
- tests/expr_and_series/dt/truncate_test.py
- tests/utils.py
- tests/v1_test.py

narwhals/_duckdb/utils.py

Lines changed: 6 additions & 8 deletions
@@ -304,15 +304,15 @@ def generate_partition_by_sql(*partition_by: str | Expression) -> str:
 
 
 def generate_order_by_sql(
-    *order_by: str | Expression, ascending: bool, nulls_first: bool
+    *order_by: str | Expression, descending: bool, nulls_last: bool
 ) -> str:
     if not order_by:
         return ""
-    nulls = "nulls first" if nulls_first else "nulls last"
-    if ascending:
-        by_sql = ", ".join([f"{parse_into_expression(x)} asc {nulls}" for x in order_by])
-    else:
+    nulls = "nulls last" if nulls_last else "nulls first"
+    if descending:
         by_sql = ", ".join([f"{parse_into_expression(x)} desc {nulls}" for x in order_by])
+    else:
+        by_sql = ", ".join([f"{parse_into_expression(x)} asc {nulls}" for x in order_by])
     return f"order by {by_sql}"
 
 
@@ -335,9 +335,7 @@ def window_expression(
         msg = f"DuckDB>=1.3.0 is required for this operation. Found: DuckDB {duckdb.__version__}"
         raise NotImplementedError(msg) from exc
     pb = generate_partition_by_sql(*partition_by)
-    ob = generate_order_by_sql(
-        *order_by, ascending=not descending, nulls_first=not nulls_last
-    )
+    ob = generate_order_by_sql(*order_by, descending=descending, nulls_last=nulls_last)
 
     if rows_start and rows_end:
         rows = f"rows between {rows_start} and {rows_end}"
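
For reference, a minimal standalone sketch (not the narwhals internals: it substitutes plain column-name strings for parse_into_expression and Expression) showing how the new descending/nulls_last keywords produce the same SQL the old ascending/nulls_first spelling did:

def order_by_sql_sketch(*order_by: str, descending: bool, nulls_last: bool) -> str:
    # Hypothetical simplified helper mirroring generate_order_by_sql's logic.
    if not order_by:
        return ""
    nulls = "nulls last" if nulls_last else "nulls first"
    direction = "desc" if descending else "asc"
    by_sql = ", ".join(f"{col} {direction} {nulls}" for col in order_by)
    return f"order by {by_sql}"

# The old call site passed ascending=not descending, nulls_first=not nulls_last;
# the new call site forwards the flags unchanged and yields identical SQL.
assert order_by_sql_sketch("a", "b", descending=False, nulls_last=False) == (
    "order by a asc nulls first, b asc nulls first"
)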

tests/expr_and_series/dt/truncate_test.py

Lines changed: 2 additions & 8 deletions
@@ -53,11 +53,6 @@ def test_truncate(
     every: str,
     expected: list[datetime],
 ) -> None:
-    if any(x in str(constructor) for x in ("sqlframe", "pyspark")):
-        # TODO(marco): investigate pyspark, it also localizes to UTC here.
-        request.applymarker(
-            pytest.mark.xfail(reason="https://github.com/eakmanrq/sqlframe/issues/383")
-        )
     if every.endswith("ns") and any(
         x in str(constructor) for x in ("polars", "duckdb", "pyspark", "ibis")
     ):
@@ -109,11 +104,10 @@ def test_truncate_multiples(
     every: str,
     expected: list[datetime],
 ) -> None:
-    if any(x in str(constructor) for x in ("sqlframe", "cudf", "pyspark", "duckdb")):
+    if any(x in str(constructor) for x in ("cudf", "pyspark", "duckdb")):
         # Reasons:
-        # - sqlframe: https://github.com/eakmanrq/sqlframe/issues/383
         # - cudf: https://github.com/rapidsai/cudf/issues/18654
-        # - pyspark: Only multiple 1 is currently supported
+        # - pyspark/sqlframe: Only multiple 1 is currently supported
         request.applymarker(pytest.mark.xfail())
     if every.endswith("ns") and any(
         x in str(constructor) for x in ("polars", "duckdb", "ibis")
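
Side note: the test marks individual parametrizations as expected failures at runtime via pytest's request fixture. A minimal sketch of that pattern, with a made-up backend string purely for illustration:

import pytest

def test_backend_specific_sketch(request: pytest.FixtureRequest) -> None:
    backend = "pyspark"  # hypothetical backend name, for illustration only
    if backend in ("cudf", "pyspark", "duckdb"):
        # xfail (rather than skip) keeps the test running, so an unexpected
        # pass is still surfaced as XPASS.
        request.applymarker(pytest.mark.xfail())
    assert backend != "pyspark"  # fails here and is reported as XFAIL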

tests/utils.py

Lines changed: 15 additions & 0 deletions
@@ -77,6 +77,10 @@ def assert_equal_data(result: Any, expected: Mapping[str, Any]) -> None:
         hasattr(result, "_compliant_frame")
         and result._compliant_frame._implementation is Implementation.IBIS
     )
+    is_spark_like = (
+        hasattr(result, "_compliant_frame")
+        and result._compliant_frame._implementation.is_spark_like()
+    )
     if is_duckdb:
         result = from_native(result.to_native().arrow())
     if is_ibis:
@@ -122,6 +126,17 @@ def assert_equal_data(result: Any, expected: Mapping[str, Any]) -> None:
                 are_equivalent_values = pd.isna(rhs)
             elif type(lhs) is date and type(rhs) is datetime:
                 are_equivalent_values = datetime(lhs.year, lhs.month, lhs.day) == rhs
+            elif (
+                is_spark_like
+                and isinstance(lhs, datetime)
+                and isinstance(rhs, datetime)
+                and rhs.tzinfo is None
+                and lhs.tzinfo
+            ):
+                # PySpark converts timezone-naive to timezone-aware by default in many cases.
+                # For now, we just assert that the local result matches the expected one.
+                # https://github.com/narwhals-dev/narwhals/issues/2793
+                are_equivalent_values = lhs.replace(tzinfo=None) == rhs
             else:
                 are_equivalent_values = lhs == rhs
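
For clarity, a small self-contained illustration of the new comparison branch, using made-up values: when a spark-like backend returns a timezone-aware datetime but the expected value is naive, only the wall-clock value is compared.

from datetime import datetime, timezone

lhs = datetime(2021, 1, 1, 12, 0, tzinfo=timezone.utc)  # aware result from the backend
rhs = datetime(2021, 1, 1, 12, 0)                        # naive expected value
# Strip the tzinfo from the result and compare local (wall-clock) values only.
assert lhs.replace(tzinfo=None) == rhs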

tests/v1_test.py

Lines changed: 1 addition & 1 deletion
@@ -396,7 +396,7 @@ def test_with_row_index(constructor: Constructor) -> None:
 
     frame = nw_v1.from_native(constructor(data))
 
-    msg = r".*argument after \* must be an iterable, not NoneType$"
+    msg = r"argument after \* must be an iterable, not NoneType|has no len"
     context = (
         pytest.raises(TypeError, match=msg)
         if any(x in str(constructor) for x in ("duckdb", "pyspark"))
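
Since pytest.raises(..., match=msg) applies re.search, the widened pattern accepts either backend's TypeError message. A quick sketch with illustrative messages (not captured from a real run):

import re

msg = r"argument after \* must be an iterable, not NoneType|has no len"
# Either alternative of the pattern is enough for a match.
assert re.search(msg, "func() argument after * must be an iterable, not NoneType")
assert re.search(msg, "object of type 'NoneType' has no len()")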
