
ibis.pyspark col_vals_in_set produces invalid SQL #335

@mark-druffel

Sorry, I hate opening so many issues in the same week. I've been testing a big PR on my side and these things keep cropping up; I don't want to lose sight of them even if I code around some of them for now.

Description

I think col_vals_in_set might not work properly when using ibis.pyspark. It looks like the issue might actually be with narwhals, but I'm having a tough time debugging.

The DuckDB validation works, but the PySpark validation generates the following SQL, which fails (scrubbed for readability):

SELECT 
  campaign_id,
  start_date, 
  end_date, 
  campaign_name, 
  channel_id, 
  target_id, 
  campaign_code, 
  outbound_file_name, 
  execution_cost, 
  campaign_manager, 
  target_cell_names, 
  channel_id IN (303, 304, 441, 461, 621) AS `pb_is_good_` 
FROM comms_media_dev.dart_core.test_campaign
WHERE channel_id IN (303, 304, 441, 461, 621) IS NULL
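
This looks like a parenthesization problem: Spark SQL's parser rejects an IN predicate followed directly by IS NULL, but accepts it once the IN predicate is wrapped in parentheses. A minimal check of that theory (the view and column names here are just for illustration):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.range(1).createOrReplaceTempView("t")

# Fails with [PARSE_SYNTAX_ERROR], same as the traceback below:
# spark.sql("SELECT id IN (1, 2) IS NULL FROM t").show()

# Parses fine once the IN predicate is parenthesized:
spark.sql("SELECT (id IN (1, 2)) IS NULL AS pb_is_good_null FROM t").show()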

Reproducible example

I created a simple DataFrame, wrote it to Parquet in the working directory, and read it into DuckDB and PySpark via ibis:

import ibis
import os
import pandas as pd
import pointblank as pb
import numpy as np
from pyspark.sql import SparkSession
from ibis import _

pdf = pd.DataFrame({
    'id': range(1, 101),
    'value': np.random.rand(100),
    'category': np.random.choice(['A', 'B', 'C'], 100)
})
parquet_path = os.path.join(os.getcwd(), "temp.parquet")
pdf.to_parquet(parquet_path, index=False)
iduck = ibis.duckdb.connect()
df_duck = iduck.read_parquet(parquet_path, table_name = "test")
spark = SparkSession.builder.getOrCreate()
ispark = ibis.pyspark.connect(spark)
df_spark = ispark.read_parquet(parquet_path, table_name = "test")

I set up a validation and ran it with DuckDB; it works as expected:

validate_duck = (
    pb.Validate(data=df_duck)
    .col_vals_in_set(columns = ["category"], set = ["A", "B"])
    .interrogate()
)
validate_duck
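
I suspect the DuckDB run succeeds because DuckDB follows Postgres-style precedence, where IN binds tighter than IS, so the unparenthesized predicate parses as (category IN ('A', 'B')) IS NULL. A quick sanity check against the same Parquet file (assuming the working directory from the repro above):

import duckdb

# DuckDB accepts the exact pattern that Spark rejects:
duckdb.sql(
    "SELECT category IN ('A', 'B') IS NULL AS is_null FROM 'temp.parquet' LIMIT 5"
).show()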

I set up the same validation and ran it with PySpark; it errors on the invalid SQL:

validate_pyspark = (
    pb.Validate(data=df_spark)
    .col_vals_in_set(columns = ["category"], set = ["A", "B"])
    .interrogate()
)
validate_pyspark
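
As a cross-check, roughly the same null count can be computed directly in ibis (so that ibis, rather than narwhals, generates the SQL); I'd expect this to run cleanly, since ibis should parenthesize the predicate itself. pb_is_good_ below just mirrors the column name pointblank uses internally:

# Rough equivalent of pointblank's internal null-count check, in ibis:
checked = df_spark.mutate(pb_is_good_=df_spark.category.isin(["A", "B"]))
n_failed = checked.filter(checked.pb_is_good_.isnull()).count().execute()
print(n_failed)  # membership tests on non-null inputs should report 0 nulls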

Expected result

The PySpark backend should have produced the same result as the DuckDB backend.

Development environment

macOS 14.5
pointblank 0.14.0

Additional context

Here's the error:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <module>:4                                                                                    │
│                                                                                                  │
│   1 validate_pyspark = (                                                                         │
│   2 │   pb.Validate(data=df_spark)                                                               │
│   3 │   .col_vals_in_set(columns = ["category"], set = ["A", "B"])                               │
│ ❱ 4 │   .interrogate()                                                                           │
│   5 )                                                                                            │
│   6 validate_pyspark                                                                             │
│   7                                                                                              │
│                                                                                                  │
│ /Users/m109993/Github/camper/.venv/lib/python3.11/site-packages/pointblank/validate.py:13244 in  │
│ interrogate                                                                                      │
│                                                                                                  │
│   13241 │   │   │   │   # Solely for the col_vals_in_set assertion type, any Null values in the  │
│   13242 │   │   │   │   # `pb_is_good_` column are counted as failing test units                 │
│   13243 │   │   │   │   if assertion_type == "col_vals_in_set":                                  │
│ ❱ 13244 │   │   │   │   │   null_count = _count_null_values_in_column(tbl=results_tbl, column="p │
│   13245 │   │   │   │   │   validation.n_failed += null_count                                    │
│   13246 │   │   │   │                                                                            │
│   13247 │   │   │   │   # For column-value validations, the number of test units is the number o │
│                                                                                                  │
│ /Users/m109993/Github/camper/.venv/lib/python3.11/site-packages/pointblank/_utils.py:375 in      │
│ _count_null_values_in_column                                                                     │
│                                                                                                  │
│   372 │                                                                                          │
│   373 │   # Always collect table if it is a LazyFrame; this is required to get the row count     │
│   374 │   if _is_lazy_frame(tbl_filtered):                                                       │
│ ❱ 375 │   │   tbl_filtered = tbl_filtered.collect()                                              │
│   376 │                                                                                          │
│   377 │   return len(tbl_filtered)                                                               │
│   378                                                                                            │
│                                                                                                  │
│ /Users/m109993/Github/camper/.venv/lib/python3.11/site-packages/narwhals/dataframe.py:2472 in    │
│ collect                                                                                          │
│                                                                                                  │
│   2469 │   │   """                                                                               │
│   2470 │   │   collect = self._compliant_frame.collect                                           │
│   2471 │   │   if backend is None:                                                               │
│ ❱ 2472 │   │   │   return self._dataframe(collect(None, **kwargs), level="full")                 │
│   2473 │   │   eager_backend = Implementation.from_backend(backend)                              │
│   2474 │   │   if can_lazyframe_collect(eager_backend):                                          │
│   2475 │   │   │   return self._dataframe(collect(eager_backend, **kwargs), level="full")        │
│                                                                                                  │
│ /Users/m109993/Github/camper/.venv/lib/python3.11/site-packages/narwhals/_ibis/dataframe.py:114  │
│ in collect                                                                                       │
│                                                                                                  │
│   111 │   │   │   from narwhals._arrow.dataframe import ArrowDataFrame                           │
│   112 │   │   │                                                                                  │
│   113 │   │   │   return ArrowDataFrame(                                                         │
│ ❱ 114 │   │   │   │   to_pyarrow_table(self.native.to_pyarrow()),                                │
│   115 │   │   │   │   validate_backend_version=True,                                             │
│   116 │   │   │   │   version=self._version,                                                     │
│   117 │   │   │   │   validate_column_names=True,                                                │
│                                                                                                  │
│ /Users/m109993/Github/camper/.venv/lib/python3.11/site-packages/ibis/expr/types/relations.py:625 │
│ in to_pyarrow                                                                                    │
│                                                                                                  │
│    622 │   │   limit: int | str | None = None,                                                   │
│    623 │   │   **kwargs: Any,                                                                    │
│    624 │   ) -> pa.Table:                                                                        │
│ ❱  625 │   │   return super().to_pyarrow(params=params, limit=limit, **kwargs)                   │
│    626 │                                                                                         │
│    627 │   def _fast_bind(self, *args, **kwargs):                                                │
│    628 │   │   # allow the first argument to be either a dictionary or a list of values          │
│                                                                                                  │
│ /Users/m109993/Github/camper/.venv/lib/python3.11/site-packages/ibis/expr/types/core.py:605 in   │
│ to_pyarrow                                                                                       │
│                                                                                                  │
│   602 │   │   │   If the passed expression is a Column, a pyarrow array is returned.             │
│   603 │   │   │   If the passed expression is a Scalar, a pyarrow scalar is returned.            │
│   604 │   │   """
│ ❱ 605 │   │   return self._find_backend(use_default=True).to_pyarrow(                            │
│   606 │   │   │   self, params=params, limit=limit, **kwargs                                     │
│   607 │   │   )                                                                                  │
│   608                                                                                            │
│                                                                                                  │
│ /Users/m109993/Github/camper/.venv/lib/python3.11/site-packages/ibis/backends/pyspark/__init__.p │
│ y:1050 in to_pyarrow                                                                             │
│                                                                                                  │
│   1047 │   │                                                                                     │
│   1048 │   │   table_expr = expr.as_table()                                                      │
│   1049 │   │   output = pa.Table.from_pandas(                                                    │
│ ❱ 1050 │   │   │   self.execute(table_expr, params=params, limit=limit, **kwargs),               │
│   1051 │   │   │   preserve_index=False,                                                         │
│   1052 │   │   )                                                                                 │
│   1053 │   │   table = PyArrowData.convert_table(output, table_expr.schema())                    │
│                                                                                                  │
│ /Users/m109993/Github/camper/.venv/lib/python3.11/site-packages/ibis/backends/pyspark/__init__.p │
│ y:506 in execute                                                                                 │
│                                                                                                  │
│    503 │   │                                                                                     │
│    504 │   │   schema = table.schema()                                                           │
│    505 │   │                                                                                     │
│ ❱  506 │   │   with self._safe_raw_sql(sql) as query:                                            │
│    507 │   │   │   df = query.toPandas()  # blocks until finished                                │
│    508 │   │   │   result = PySparkPandasData.convert_table(df, schema)                          │
│    509 │   │   return expr.__pandas_result__(result)                                             │
│                                                                                                  │
│ /opt/homebrew/Cellar/[email protected]/3.11.10/Frameworks/Python.framework/Versions/3.11/lib/python3.1 │
│ 1/contextlib.py:137 in __enter__                                                                 │
│                                                                                                  │
│   134 │   │   # they are only needed for recreation, which is not possible anymore               │
│   135 │   │   del self.args, self.kwds, self.func                                                │
│   136 │   │   try:                                                                               │
│ ❱ 137 │   │   │   return next(self.gen)                                                          │
│   138 │   │   except StopIteration:                                                              │
│   139 │   │   │   raise RuntimeError("generator didn't yield") from None                         │
│   140                                                                                            │
│                                                                                                  │
│ /Users/m109993/Github/camper/.venv/lib/python3.11/site-packages/ibis/backends/pyspark/__init__.p │
│ y:482 in _safe_raw_sql                                                                           │
│                                                                                                  │
│    479 │                                                                                         │
│    480 │   @contextlib.contextmanager                                                            │
│    481 │   def _safe_raw_sql(self, query: str) -> Any:                                           │
│ ❱  482 │   │   yield self.raw_sql(query)                                                         │
│    483 │                                                                                         │
│    484 │   def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:                  │
│    485 │   │   with contextlib.suppress(AttributeError):                                         │
│                                                                                                  │
│ /Users/m109993/Github/camper/.venv/lib/python3.11/site-packages/ibis/backends/pyspark/__init__.p │
│ y:487 in raw_sql                                                                                 │
│                                                                                                  │
│    484 │   def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:                  │
│    485 │   │   with contextlib.suppress(AttributeError):                                         │
│    486 │   │   │   query = query.sql(dialect=self.dialect)                                       │
│ ❱  487 │   │   return self._session.sql(query, **kwargs)                                         │
│    488 │                                                                                         │
│    489 │   def execute(                                                                          │
│    490 │   │   self,                                                                             │
│                                                                                                  │
│ /Users/m109993/Github/camper/.venv/lib/python3.11/site-packages/pyspark/sql/session.py:1631 in   │
│ sql                                                                                              │
│                                                                                                  │
│   1628 │   │   │   │   litArgs = self._jvm.PythonUtils.toArray(                                  │
│   1629 │   │   │   │   │   [_to_java_column(lit(v)) for v in (args or [])]                       │
│   1630 │   │   │   │   )                                                                         │
│ ❱ 1631 │   │   │   return DataFrame(self._jsparkSession.sql(sqlQuery, litArgs), self)            │
│   1632 │   │   finally:                                                                          │
│   1633 │   │   │   if len(kwargs) > 0:                                                           │
│   1634 │   │   │   │   formatter.clear()                                                         │
│                                                                                                  │
│ /Users/m109993/Github/camper/.venv/lib/python3.11/site-packages/py4j/java_gateway.py:1322 in     │
│ __call__                                                                                         │
│                                                                                                  │
│   1319 │   │   │   proto.END_COMMAND_PART                                                        │
│   1320 │   │                                                                                     │
│   1321 │   │   answer = self.gateway_client.send_command(command)                                │
│ ❱ 1322 │   │   return_value = get_return_value(                                                  │
│   1323 │   │   │   answer, self.gateway_client, self.target_id, self.name)                       │
│   1324 │   │                                                                                     │
│   1325 │   │   for temp_arg in temp_args:                                                        │
│                                                                                                  │
│ /Users/m109993/Github/camper/.venv/lib/python3.11/site-packages/pyspark/errors/exceptions/captur │
│ ed.py:185 in deco                                                                                │
│                                                                                                  │
│   182 │   │   │   if not isinstance(converted, UnknownException):                                │
│   183 │   │   │   │   # Hide where the exception came from that shows a non-Pythonic             │
│   184 │   │   │   │   # JVM exception message.                                                   │
│ ❱ 185 │   │   │   │   raise converted from None                                                  │
│   186 │   │   │   else:                                                                          │
│   187 │   │   │   │   raise                                                                      │
│   188                                                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ParseException: 
[PARSE_SYNTAX_ERROR] Syntax error at or near 'IS'.(line 1, pos 152)

== SQL ==
SELECT `t0`.`id`, `t0`.`value`, `t0`.`category`, `t0`.`category` IN ('A', 'B') AS `pb_is_good_` FROM `test` AS `t0`
WHERE `t0`.`category` IN ('A', 'B') IS NULL
-------------------------------------------------------------------------------------------------------------------
-------------------------------------^^^
