Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@ including other versions of pandas.
Enhancements
~~~~~~~~~~~~

.. _whatsnew_220.enhancements.enhancement1:
.. _whatsnew_220.enhancements.pyarrow_on_bad_lines:

enhancement1
^^^^^^^^^^^^

.. _whatsnew_220.enhancements.enhancement2:

Expand All @@ -28,8 +26,8 @@ enhancement2

Other enhancements
^^^^^^^^^^^^^^^^^^
-
-
- Addition of the capability to handle malformed lines in CSV files to the `PyArrow <https://arrow.apache.org/docs/python/index.html>`_ engine using the ``on_bad_lines`` parameter. (:issue:`54480`)


.. ---------------------------------------------------------------------------
.. _whatsnew_220.notable_bug_fixes:
Expand Down
27 changes: 27 additions & 0 deletions pandas/io/parsers/arrow_parser_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING
import warnings

from pandas._config import using_pyarrow_string_dtype

from pandas._libs import lib
from pandas.compat._optional import import_optional_dependency
from pandas.errors import ParserWarning
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.inference import is_integer

Expand Down Expand Up @@ -85,6 +88,30 @@ def _get_pyarrow_options(self) -> None:
and option_name
in ("delimiter", "quote_char", "escape_char", "ignore_empty_lines")
}

if "on_bad_lines" in self.kwds:
on_bad_lines = self.kwds["on_bad_lines"]
if callable(on_bad_lines):
self.parse_options["invalid_row_handler"] = on_bad_lines
elif on_bad_lines == ParserBase.BadLineHandleMethod.ERROR:
self.parse_options[
"invalid_row_handler"
] = None # PyArrow raises an exception by default
elif on_bad_lines == ParserBase.BadLineHandleMethod.WARN:

def handle_warning(invalid_row):
warnings.warn(
f"Expected {invalid_row.expected_columns} columns, but found "
f"{invalid_row.actual_columns}: {invalid_row.text}",
ParserWarning,
stacklevel=find_stack_level(),
)
return "skip"

self.parse_options["invalid_row_handler"] = handle_warning
elif on_bad_lines == ParserBase.BadLineHandleMethod.SKIP:
self.parse_options["invalid_row_handler"] = lambda _: "skip"

self.convert_options = {
option_name: option_value
for option_name, option_value in self.kwds.items()
Expand Down
13 changes: 10 additions & 3 deletions pandas/io/parsers/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,13 @@
expected, a ``ParserWarning`` will be emitted while dropping extra elements.
Only supported when ``engine='python'``

.. versionchanged:: 2.2.0

- Callable, function with signature
as described in `pyarrow documentation
<https://arrow.apache.org/docs/python/generated/pyarrow.csv.ParseOptions.html
#pyarrow.csv.ParseOptions.invalid_row_handler>`_ when ``engine='pyarrow'``

delim_whitespace : bool, default False
Specifies whether or not whitespace (e.g. ``' '`` or ``'\\t'``) will be
used as the ``sep`` delimiter. Equivalent to setting ``sep='\\s+'``. If this option
Expand Down Expand Up @@ -484,7 +491,6 @@ class _Fwf_Defaults(TypedDict):
"thousands",
"memory_map",
"dialect",
"on_bad_lines",
"delim_whitespace",
"quoting",
"lineterminator",
Expand Down Expand Up @@ -2053,9 +2059,10 @@ def _refine_defaults_read(
elif on_bad_lines == "skip":
kwds["on_bad_lines"] = ParserBase.BadLineHandleMethod.SKIP
elif callable(on_bad_lines):
if engine != "python":
if engine not in ["python", "pyarrow"]:
raise ValueError(
"on_bad_line can only be a callable function if engine='python'"
"on_bad_line can only be a callable function "
"if engine='python' or 'pyarrow'"
)
kwds["on_bad_lines"] = on_bad_lines
else:
Expand Down
104 changes: 82 additions & 22 deletions pandas/tests/io/parser/common/test_read_errors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Tests that work on both the Python and C engines but do not have a
Tests that work on the Python, C and PyArrow engines but do not have a
specific classification into the other test modules.
"""
import codecs
Expand All @@ -15,12 +15,20 @@
from pandas.errors import (
EmptyDataError,
ParserError,
ParserWarning,
)

from pandas import DataFrame
import pandas._testing as tm

pytestmark = pytest.mark.usefixtures("pyarrow_skip")
# PyArrow's error types are only importable when pyarrow is installed
try:
from pyarrow import ArrowInvalid
except ImportError:
pass

xfail_pyarrow = pytest.mark.usefixtures("pyarrow_xfail")
skip_pyarrow = pytest.mark.usefixtures("pyarrow_skip")


def test_empty_decimal_marker(all_parsers):
Expand All @@ -32,10 +40,17 @@ def test_empty_decimal_marker(all_parsers):
msg = "Only length-1 decimal markers supported"
parser = all_parsers

if parser.engine == "pyarrow":
msg = (
"only single character unicode strings can be "
"converted to Py_UCS4, got length 0"
)

with pytest.raises(ValueError, match=msg):
parser.read_csv(StringIO(data), decimal="")


@skip_pyarrow
def test_bad_stream_exception(all_parsers, csv_dir_path):
# see gh-13652
#
Expand All @@ -56,6 +71,7 @@ def test_bad_stream_exception(all_parsers, csv_dir_path):
parser.read_csv(stream)


@skip_pyarrow
def test_malformed(all_parsers):
# see gh-6607
parser = all_parsers
Expand All @@ -70,6 +86,7 @@ def test_malformed(all_parsers):
parser.read_csv(StringIO(data), header=1, comment="#")


@skip_pyarrow
@pytest.mark.parametrize("nrows", [5, 3, None])
def test_malformed_chunks(all_parsers, nrows):
data = """ignore
Expand All @@ -89,6 +106,7 @@ def test_malformed_chunks(all_parsers, nrows):
reader.read(nrows)


@skip_pyarrow
def test_catch_too_many_names(all_parsers):
# see gh-5156
data = """\
Expand All @@ -108,6 +126,7 @@ def test_catch_too_many_names(all_parsers):
parser.read_csv(StringIO(data), header=0, names=["a", "b", "c", "d"])


@skip_pyarrow
@pytest.mark.parametrize("nrows", [0, 1, 2, 3, 4, 5])
def test_raise_on_no_columns(all_parsers, nrows):
parser = all_parsers
Expand Down Expand Up @@ -135,11 +154,16 @@ def test_suppress_error_output(all_parsers, capsys):
data = "a\n1\n1,2,3\n4\n5,6,7"
expected = DataFrame({"a": [1, 4]})

result = parser.read_csv(StringIO(data), on_bad_lines="skip")
tm.assert_frame_equal(result, expected)
if parser.engine == "pyarrow":
with tm.assert_produces_warning(False):
result = parser.read_csv(StringIO(data), on_bad_lines="skip")
tm.assert_frame_equal(result, expected)
else:
result = parser.read_csv(StringIO(data), on_bad_lines="skip")
tm.assert_frame_equal(result, expected)

captured = capsys.readouterr()
assert captured.err == ""
captured = capsys.readouterr()
assert captured.err == ""


def test_error_bad_lines(all_parsers):
Expand All @@ -148,7 +172,13 @@ def test_error_bad_lines(all_parsers):
data = "a\n1\n1,2,3\n4\n5,6,7"

msg = "Expected 1 fields in line 3, saw 3"
with pytest.raises(ParserError, match=msg):
ex_type = ParserError

if parser.engine == "pyarrow":
msg = "CSV parse error: Expected 1 columns, got 3: 1,2,3"
ex_type = ArrowInvalid

with pytest.raises(ex_type, match=msg):
parser.read_csv(StringIO(data), on_bad_lines="error")


Expand All @@ -158,12 +188,21 @@ def test_warn_bad_lines(all_parsers, capsys):
data = "a\n1\n1,2,3\n4\n5,6,7"
expected = DataFrame({"a": [1, 4]})

result = parser.read_csv(StringIO(data), on_bad_lines="warn")
tm.assert_frame_equal(result, expected)
if parser.engine == "pyarrow":
with tm.assert_produces_warning(
ParserWarning,
check_stacklevel=False,
match="Expected 1 columns, but found 3: 1,2,3",
):
result = parser.read_csv(StringIO(data), on_bad_lines="warn")
tm.assert_frame_equal(result, expected)
else:
result = parser.read_csv(StringIO(data), on_bad_lines="warn")
tm.assert_frame_equal(result, expected)

captured = capsys.readouterr()
assert "Skipping line 3" in captured.err
assert "Skipping line 5" in captured.err
captured = capsys.readouterr()
assert "Skipping line 3" in captured.err
assert "Skipping line 5" in captured.err


def test_read_csv_wrong_num_columns(all_parsers):
Expand All @@ -175,11 +214,17 @@ def test_read_csv_wrong_num_columns(all_parsers):
"""
parser = all_parsers
msg = "Expected 6 fields in line 3, saw 7"
ex_type = ParserError

with pytest.raises(ParserError, match=msg):
if parser.engine == "pyarrow":
msg = "Expected 6 columns, got 7: 6,7,8,9,10,11,12"
ex_type = ArrowInvalid

with pytest.raises(ex_type, match=msg):
parser.read_csv(StringIO(data))


@skip_pyarrow
def test_null_byte_char(request, all_parsers):
# see gh-2741
data = "\x00,foo"
Expand All @@ -202,6 +247,7 @@ def test_null_byte_char(request, all_parsers):
parser.read_csv(StringIO(data), names=names)


@skip_pyarrow
@pytest.mark.filterwarnings("always::ResourceWarning")
def test_open_file(request, all_parsers):
# GH 39024
Expand Down Expand Up @@ -235,13 +281,17 @@ def test_bad_header_uniform_error(all_parsers):
parser = all_parsers
data = "+++123456789...\ncol1,col2,col3,col4\n1,2,3,4\n"
msg = "Expected 2 fields in line 2, saw 4"
ex_type = ParserError
if parser.engine == "c":
msg = (
"Could not construct index. Requested to use 1 "
"number of columns, but 3 left to parse."
)
elif parser.engine == "pyarrow":
ex_type = ArrowInvalid
msg = "CSV parse error: Expected 1 columns, got 4: col1,col2,col3,col4"

with pytest.raises(ParserError, match=msg):
with pytest.raises(ex_type, match=msg):
parser.read_csv(StringIO(data), index_col=0, on_bad_lines="error")


Expand All @@ -256,17 +306,27 @@ def test_on_bad_lines_warn_correct_formatting(all_parsers, capsys):
"""
expected = DataFrame({"1": "a", "2": ["b"] * 2})

result = parser.read_csv(StringIO(data), on_bad_lines="warn")
tm.assert_frame_equal(result, expected)
# pyarrow engine uses warnings instead of directly printing to stderr
if parser.engine == "pyarrow":
with tm.assert_produces_warning(
ParserWarning,
check_stacklevel=False,
match="Expected 2 columns, but found 3: a,b,c",
):
result = parser.read_csv(StringIO(data), on_bad_lines="warn")
tm.assert_frame_equal(result, expected)
else:
result = parser.read_csv(StringIO(data), on_bad_lines="warn")
tm.assert_frame_equal(result, expected)

captured = capsys.readouterr()
if parser.engine == "c":
warn = """Skipping line 3: expected 2 fields, saw 3
captured = capsys.readouterr()
if parser.engine == "c":
warn = """Skipping line 3: expected 2 fields, saw 3
Skipping line 4: expected 2 fields, saw 3

"""
else:
warn = """Skipping line 3: Expected 2 fields in line 3, saw 3
else:
warn = """Skipping line 3: Expected 2 fields in line 3, saw 3
Skipping line 4: Expected 2 fields in line 4, saw 3
"""
assert captured.err == warn
assert captured.err == warn
10 changes: 7 additions & 3 deletions pandas/tests/io/parser/test_unsupported.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,17 @@ def test_pyarrow_engine(self):
with pytest.raises(ValueError, match=msg):
read_csv(StringIO(data), engine="pyarrow", **kwargs)

def test_on_bad_lines_callable_python_only(self, all_parsers):
def test_on_bad_lines_callable_python_or_pyarrow(self, all_parsers):
# GH 5686
# GH 54643
sio = StringIO("a,b\n1,2")
bad_lines_func = lambda x: x
parser = all_parsers
if all_parsers.engine != "python":
msg = "on_bad_line can only be a callable function if engine='python'"
if all_parsers.engine not in ["python", "pyarrow"]:
msg = (
"on_bad_line can only be a callable "
"function if engine='python' or 'pyarrow'"
)
with pytest.raises(ValueError, match=msg):
parser.read_csv(sio, on_bad_lines=bad_lines_func)
else:
Expand Down