Skip to content

Commit 28f4051

Browse files
committed
Add tests to existing test file, test_to_csv.py
1 parent 36a26a9 commit 28f4051

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

pandas/tests/io/formats/test_to_csv.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,3 +741,55 @@ def test_to_csv_iterative_compression_buffer(compression):
741741
pd.read_csv(buffer, compression=compression, index_col=0), df
742742
)
743743
assert not buffer.closed
744+
745+
746+
def test_preserve_numpy_arrays_in_csv(self):
    """Round-trip a frame holding ndarray cells through CSV.

    The frame is written with ``to_csv(..., preserve_complex=True)`` and
    read back with ``read_csv(..., preserve_complex=True)``; the first
    cell of the ``embedding`` column must come back as ``np.ndarray``.
    """
    # NOTE(review): this takes ``self`` although the function shown just
    # above in the diff context is a plain module-level test -- confirm it
    # really sits inside a test class, otherwise pytest will treat ``self``
    # as a fixture name.
    embeddings = [
        np.array([0.1, 0.2, 0.3]),
        np.array([0.4, 0.5, 0.6]),
    ]
    frame = pd.DataFrame({"id": [1, 2], "embedding": embeddings})
    failure_msg = "Test Failed: The CSV did not preserve embeddings as NumPy arrays!"

    with tm.ensure_clean("test.csv") as csv_path:
        frame.to_csv(csv_path, index=False, preserve_complex=True)
        round_tripped = pd.read_csv(csv_path, preserve_complex=True)

        # Validate that embeddings are still NumPy arrays.
        first_embedding = round_tripped["embedding"][0]
        assert isinstance(first_embedding, np.ndarray), failure_msg
763+
764+
765+
def test_preserve_numpy_arrays_in_csv_empty_dataframe(self):
    """Write an empty single-column frame with ``preserve_complex=True``.

    With no rows and ``index=False`` the file is expected to contain only
    the header line for the ``embedding`` column.
    """
    # NOTE(review): this takes ``self`` although the function shown just
    # above in the diff context is a plain module-level test -- confirm it
    # really sits inside a test class, otherwise pytest will treat ``self``
    # as a fixture name.
    df = pd.DataFrame({"embedding": []})

    # Header line plus its terminator.  Spelled as a plain literal: a
    # ``\e`` sequence is not a recognised escape, so it would leave a
    # literal backslash at the start of the expected text.
    expected = "embedding\n"

    with tm.ensure_clean("test.csv") as path:
        df.to_csv(path, index=False, preserve_complex=True)
        # Text mode uses universal newlines, so the platform line
        # terminator is read back as "\n".
        with open(path, encoding="utf-8") as f:
            result = f.read()

    assert result == expected, (
        f"CSV output mismatch for empty DataFrame.\nGot:\n{result}"
    )
775+
776+
777+
def test_preserve_numpy_arrays_in_csv_mixed_dtypes(self):
    """Round-trip a frame mixing int, string and ndarray columns.

    After ``to_csv``/``read_csv`` with ``preserve_complex=True`` the
    ``scores`` cells must be ``np.ndarray`` again, while ``id`` and ``age``
    must read back as ``int64`` and ``name`` as ``object``.
    """
    # NOTE(review): this takes ``self`` although the function shown just
    # above in the diff context is a plain module-level test -- confirm it
    # really sits inside a test class, otherwise pytest will treat ``self``
    # as a fixture name.
    columns = {
        "id": [101, 102],
        "name": ["alice", "bob"],
        "scores": [np.array([95.5, 88.0]), np.array([76.0, 90.5])],
        "age": [25, 30],
    }
    frame = pd.DataFrame(columns)

    with tm.ensure_clean("test.csv") as csv_path:
        frame.to_csv(csv_path, index=False, preserve_complex=True)
        round_tripped = pd.read_csv(csv_path, preserve_complex=True)

        # The array-valued column must be rebuilt as real ndarrays.
        first_scores = round_tripped["scores"][0]
        assert isinstance(
            first_scores, np.ndarray
        ), "Failed: 'scores' column not deserialized as np.ndarray."

        # The plain columns must keep their usual inferred dtypes.
        id_dtype = round_tripped["id"].dtype
        name_dtype = round_tripped["name"].dtype
        age_dtype = round_tripped["age"].dtype
        assert id_dtype == np.int64, "Failed: 'id' should still be int."
        assert (
            name_dtype == object
        ), "Failed: 'name' should still be string/object."
        assert age_dtype == np.int64, "Failed: 'age' should still be int."

0 commit comments

Comments
 (0)