diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 10b56011c9640..ca5410736c228 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -208,6 +208,7 @@ Other enhancements - :meth:`DataFrame.apply` supports using third-party execution engines like the Bodo.ai JIT compiler (:issue:`60668`) - :meth:`DataFrame.iloc` and :meth:`Series.iloc` now support boolean masks in ``__getitem__`` for more consistent indexing behavior (:issue:`60994`) - :meth:`DataFrame.to_csv` and :meth:`Series.to_csv` now support Python's new-style format strings (e.g., ``"{:.6f}"``) for the ``float_format`` parameter, in addition to old-style ``%`` format strings and callables. This allows for more flexible and modern formatting of floating point numbers when exporting to CSV. (:issue:`49580`) +- :meth:`DataFrame.to_sql` and :func:`to_sql` now accept a ``nullable`` parameter to specify which columns should allow NULL values. This allows control over the NOT NULL constraint when creating SQL tables, supporting use cases like programmatic table creation from data dictionaries (:issue:`63116`) - :meth:`DataFrameGroupBy.transform`, :meth:`SeriesGroupBy.transform`, :meth:`DataFrameGroupBy.agg`, :meth:`SeriesGroupBy.agg`, :meth:`RollingGroupby.apply`, :meth:`ExpandingGroupby.apply`, :meth:`Rolling.apply`, :meth:`Expanding.apply`, :meth:`DataFrame.apply` with ``engine="numba"`` now supports positional arguments passed as kwargs (:issue:`58995`) - :meth:`Rolling.agg`, :meth:`Expanding.agg` and :meth:`ExponentialMovingWindow.agg` now accept :class:`NamedAgg` aggregations through ``**kwargs`` (:issue:`28333`) - :meth:`Series.map` can now accept kwargs to pass on to func (:issue:`59814`) @@ -232,7 +233,6 @@ Other enhancements - Support reading Stata 102-format (Stata 1) dta files (:issue:`58978`) - Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`) - Switched wheel upload to **PyPI Trusted Publishing** (OIDC) for release-tag pushes in ``wheels.yml``. (:issue:`61718`) -- .. --------------------------------------------------------------------------- .. _whatsnew_300.notable_bug_fixes: diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 25e0aa6b8f072..5802ef6b58eb7 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -2798,6 +2798,7 @@ def to_sql( chunksize: int | None = None, dtype: DtypeArg | None = None, method: Literal["multi"] | Callable | None = None, + nullable: dict[str, bool] | None = None, ) -> int | None: """ Write records stored in a DataFrame to a SQL database. @@ -2861,6 +2862,25 @@ def to_sql( Details and a sample callable implementation can be found in the section :ref:`insert method `. + nullable : dict, optional + Specifies whether columns should allow NULL values. If a dictionary is used, + the keys should be the column names and the values should be boolean values. + ``True`` indicates the column is nullable (can contain NULL), + ``False`` indicates the column is NOT NULL. + + For SQLAlchemy connections: If a column is not specified in the dictionary, + the default behavior is typically nullable=True. + + For ADBC connections: The PyArrow table schema is modified to set the + nullability constraint. If data contains NULL values for a column marked + as ``nullable=False``, a ValueError will be raised. + + This parameter only applies when creating a new table or replacing an existing + table (i.e., when ``if_exists='fail'`` and table doesn't exist, ``if_exists='replace'``). + When ``if_exists='append'``, this parameter is ignored as the table schema + already exists. + + .. versionadded:: 3.0.0 Returns ------- @@ -3013,6 +3033,26 @@ def to_sql( ... conn.execute(text("SELECT * FROM integers")).fetchall() [(1,), (None,), (2,)] + Specify nullable constraints when creating a table. This is useful for + enforcing NOT NULL constraints based on data dictionaries or schemas. + + >>> df = pd.DataFrame({'id': [1, 2, 3], 'name': ['Alice', 'Bob', 'Charlie']}) + >>> df.to_sql(name='users_with_constraints', con=engine, if_exists='replace', + ... index=False, nullable={'id': False, 'name': False}) + 3 + + The table is created with NOT NULL constraints on id and name columns: + + >>> with engine.connect() as conn: + ... result = conn.execute(text( + ... "SELECT sql FROM sqlite_master WHERE name='users_with_constraints'" + ... )).fetchone() # doctest:+SKIP + ... print(result[0]) # doctest:+SKIP + CREATE TABLE users_with_constraints ( + id BIGINT NOT NULL, + name TEXT NOT NULL + ) + .. versionadded:: 2.2.0 pandas now supports writing via ADBC drivers @@ -3042,6 +3082,7 @@ def to_sql( chunksize=chunksize, dtype=dtype, method=method, + nullable=nullable, ) @final diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 52adbd42c4479..b41952e995246 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -751,6 +751,7 @@ def to_sql( dtype: DtypeArg | None = None, method: Literal["multi"] | Callable | None = None, engine: str = "auto", + nullable: dict[str, bool] | None = None, **engine_kwargs, ) -> int | None: """ @@ -808,6 +809,25 @@ def to_sql( SQL engine library to use. If 'auto', then the option ``io.sql.engine`` is used. The default ``io.sql.engine`` behavior is 'sqlalchemy' + nullable : dict, optional + Specifies whether columns should allow NULL values. If a dictionary is used, + the keys should be the column names and the values should be boolean values. + ``True`` indicates the column is nullable (can contain NULL), + ``False`` indicates the column is NOT NULL. + + For SQLAlchemy connections: If a column is not specified in the dictionary, + the default behavior is typically nullable=True. + + For ADBC connections: The PyArrow table schema is modified to set the + nullability constraint. If data contains NULL values for a column marked + as ``nullable=False``, a ValueError will be raised. + + This parameter only applies when creating a new table or replacing an existing + table (i.e., when ``if_exists='fail'`` and table doesn't exist, ``if_exists='replace'``). + When ``if_exists='append'``, this parameter is ignored as the table schema + already exists. + + .. versionadded:: 3.0.0 **engine_kwargs Any additional kwargs are passed to the engine. @@ -849,6 +869,7 @@ def to_sql( dtype=dtype, method=method, engine=engine, + nullable=nullable, **engine_kwargs, ) @@ -944,6 +965,7 @@ def __init__( schema=None, keys=None, dtype: DtypeArg | None = None, + nullable: dict[str, bool] | None = None, ) -> None: self.name = name self.pd_sql = pandas_sql_engine @@ -954,6 +976,7 @@ def __init__( self.if_exists = if_exists self.keys = keys self.dtype = dtype + self.nullable = nullable if frame is not None: # We want to initialize based on a dataframe @@ -1267,10 +1290,15 @@ def _create_table_setup(self): column_names_and_types = self._get_column_names_and_types(self._sqlalchemy_type) - columns: list[Any] = [ - Column(name, typ, index=is_index) - for name, typ, is_index in column_names_and_types - ] + columns: list[Any] = [] + for name, typ, is_index in column_names_and_types: + if self.nullable is not None and name in self.nullable: + nullable_value = self.nullable[name] + columns.append( + Column(name, typ, index=is_index, nullable=nullable_value) + ) + else: + columns.append(Column(name, typ, index=is_index)) if self.keys is not None: if not is_list_like(self.keys): @@ -1504,6 +1532,7 @@ def to_sql( dtype: DtypeArg | None = None, method: Literal["multi"] | Callable | None = None, engine: str = "auto", + nullable: dict[str, bool] | None = None, **engine_kwargs, ) -> int | None: pass @@ -1893,6 +1922,7 @@ def prep_table( index_label=None, schema=None, dtype: DtypeArg | None = None, + nullable: dict[str, bool] | None = None, ) -> SQLTable: """ Prepares table in the database for data insertion. Creates it if needed, etc. @@ -1928,6 +1958,7 @@ def prep_table( index_label=index_label, schema=schema, dtype=dtype, + nullable=nullable, ) table.create() return table @@ -1973,6 +2004,7 @@ def to_sql( dtype: DtypeArg | None = None, method: Literal["multi"] | Callable | None = None, engine: str = "auto", + nullable: dict[str, bool] | None = None, **engine_kwargs, ) -> int | None: """ @@ -2032,6 +2064,7 @@ def to_sql( index_label=index_label, schema=schema, dtype=dtype, + nullable=nullable, ) total_inserted = sql_engine.insert_records( @@ -2333,6 +2366,7 @@ def to_sql( dtype: DtypeArg | None = None, method: Literal["multi"] | Callable | None = None, engine: str = "auto", + nullable: dict[str, bool] | None = None, **engine_kwargs, ) -> int | None: """ @@ -2362,6 +2396,15 @@ def to_sql( Raises NotImplementedError method : {None', 'multi', callable}, default None Raises NotImplementedError + nullable : dict, optional + Specifies whether columns should allow NULL values. If a dictionary is used, + the keys should be column names and the values should be boolean values. + ``True`` indicates the column is nullable (can contain NULL), + ``False`` indicates the column is NOT NULL. + Only applies when creating or replacing tables (``if_exists='fail'`` or + ``if_exists='replace'``). When ``if_exists='append'``, this parameter is + ignored. If the data contains NULL values for a column marked as + ``nullable=False``, a ValueError will be raised. engine : {'auto', 'sqlalchemy'}, default 'auto' Raises NotImplementedError if not set to 'auto' """ @@ -2410,6 +2453,27 @@ def to_sql( except pa.ArrowNotImplementedError as exc: raise ValueError("datatypes not supported") from exc + if nullable and mode == "create": + current_schema = tbl.schema + new_fields = [] + + for field in current_schema: + if field.name in nullable: + if not nullable[field.name]: + col_data = tbl.column(field.name) + if col_data.null_count > 0: + raise ValueError( + f"Column '{field.name}' contains {col_data.null_count} " + f"null value(s) but nullable=False was specified" + ) + new_field = field.with_nullable(nullable[field.name]) + new_fields.append(new_field) + else: + new_fields.append(field) + + new_schema = pa.schema(new_fields, metadata=current_schema.metadata) + tbl = tbl.cast(new_schema) + with self.con.cursor() as cur: try: total_inserted = cur.adbc_ingest( @@ -2588,9 +2652,13 @@ def _create_table_setup(self): column_names_and_types = self._get_column_names_and_types(self._sql_type_name) escape = _get_valid_sqlite_name - create_tbl_stmts = [ - escape(cname) + " " + ctype for cname, ctype, _ in column_names_and_types - ] + create_tbl_stmts = [] + for cname, ctype, _ in column_names_and_types: + col_def = escape(cname) + " " + ctype + if self.nullable is not None and cname in self.nullable: + if not self.nullable[cname]: + col_def += " NOT NULL" + create_tbl_stmts.append(col_def) if self.keys is not None and len(self.keys): if not is_list_like(self.keys): @@ -2810,6 +2878,7 @@ def to_sql( dtype: DtypeArg | None = None, method: Literal["multi"] | Callable | None = None, engine: str = "auto", + nullable: dict[str, bool] | None = None, **engine_kwargs, ) -> int | None: """ @@ -2875,6 +2944,7 @@ def to_sql( if_exists=if_exists, index_label=index_label, dtype=dtype, + nullable=nullable, ) table.create() return table.insert(chunksize, method) diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 5865c46b4031e..4aa578001ae86 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -4398,3 +4398,141 @@ def test_xsqlite_if_exists(sqlite_buildin): (5, "E"), ] drop_table(table_name, sqlite_buildin) + + +@pytest.mark.parametrize("conn", ["sqlite_engine"]) +def test_nullable_column(conn, request): + pytest.importorskip("sqlalchemy") + from sqlalchemy import text + + conn = request.getfixturevalue(conn) + + df = DataFrame({"A": [1, 2, 3], "B": ["a", "b", "c"], "C": [1.1, 2.2, 3.3]}) + + nullable = {"A": False, "B": True} + + df.to_sql( + "test_nullable", conn, if_exists="replace", index=False, nullable=nullable + ) + + result = pd.read_sql(text("PRAGMA table_info(test_nullable)"), conn) + a_nullable = result[result["name"] == "A"]["notnull"].iloc[0] + b_nullable = result[result["name"] == "B"]["notnull"].iloc[0] + c_nullable = result[result["name"] == "C"]["notnull"].iloc[0] + + assert a_nullable == 1 # NOT NULL + assert b_nullable == 0 # NULL allowed + assert c_nullable == 0 # NULL allowed (default) + + drop_table("test_nullable", conn) + + +@pytest.mark.parametrize("conn", ["sqlite_engine"]) +def test_nullable_with_index(conn, request): + pytest.importorskip("sqlalchemy") + conn = request.getfixturevalue(conn) + + df = DataFrame( + {"A": [1, 2, 3], "B": ["a", "b", "c"]}, + index=Index([10, 20, 30], name="idx"), + ) + + nullable = {"idx": False, "A": False} + + df.to_sql( + "test_nullable_idx", conn, if_exists="replace", index=True, nullable=nullable + ) + + result = pd.read_sql("SELECT * FROM test_nullable_idx", conn) + + assert "idx" in result.columns + assert "A" in result.columns + assert "B" in result.columns + assert len(result) == 3 + + drop_table("test_nullable_idx", conn) + + +def test_nullable_sqlite_builtin(sqlite_buildin): + df = DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]}) + nullable = {"A": False} + + df.to_sql( + "test_nullable_sqlite", + sqlite_buildin, + if_exists="replace", + index=False, + nullable=nullable, + ) + + result = pd.read_sql("PRAGMA table_info(test_nullable_sqlite)", sqlite_buildin) + a_nullable = result[result["name"] == "A"]["notnull"].iloc[0] + b_nullable = result[result["name"] == "B"]["notnull"].iloc[0] + + assert a_nullable == 1 # NOT NULL + assert b_nullable == 0 # NULL allowed + + drop_table("test_nullable_sqlite", sqlite_buildin) + + +@pytest.mark.parametrize("conn", adbc_connectable) +def test_nullable_adbc(conn, request): + pytest.importorskip("pyarrow") + pytest.importorskip("adbc_driver_manager") + + conn = request.getfixturevalue(conn) + + df = DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]}) + nullable = {"A": False, "B": True} + + df.to_sql( + "test_nullable_adbc", conn, if_exists="replace", index=False, nullable=nullable + ) + + result = pd.read_sql("SELECT * FROM test_nullable_adbc", conn) + assert len(result) == 3 + assert list(result.columns) == ["A", "B"] + + drop_table("test_nullable_adbc", conn) + + +@pytest.mark.parametrize("conn", adbc_connectable) +def test_nullable_adbc_with_nulls_raises(conn, request): + pytest.importorskip("pyarrow") + pytest.importorskip("adbc_driver_manager") + + conn = request.getfixturevalue(conn) + + df = DataFrame({"A": [1, None, 3], "B": ["x", "y", "z"]}) + nullable = {"A": False} # A has nulls but we say it shouldn't + + msg = "Column 'A' contains 1 null value\\(s\\) but nullable=False was specified" + with pytest.raises(ValueError, match=msg): + df.to_sql( + "test_table", conn, if_exists="replace", index=False, nullable=nullable + ) + + +@pytest.mark.parametrize("conn", adbc_connectable) +def test_nullable_adbc_append_ignored(conn, request): + pytest.importorskip("pyarrow") + pytest.importorskip("adbc_driver_manager") + + conn = request.getfixturevalue(conn) + + df1 = DataFrame({"A": [1, 2], "B": ["x", "y"]}) + df1.to_sql("test_nullable_append", conn, if_exists="replace", index=False) + + df2 = DataFrame({"A": [3, 4], "B": ["z", "w"]}) + df2.to_sql( + "test_nullable_append", + conn, + if_exists="append", + index=False, + nullable={"A": False}, + ) + + result = pd.read_sql("SELECT * FROM test_nullable_append", conn) + assert len(result) == 4 + + drop_table("test_nullable_append", conn)