1
- import os
2
1
import tempfile
3
2
4
3
import numpy as np
4
+
5
5
import pandas as pd
6
6
7
7
@@ -11,42 +11,39 @@ def test_preserve_numpy_arrays_in_csv():
11
11
"id" : [1 , 2 ],
12
12
"embedding" : [
13
13
np .array ([0.1 , 0.2 , 0.3 ]),
14
- np .array ([0.4 , 0.5 , 0.6 ])
14
+ np .array ([0.4 , 0.5 , 0.6 ]),
15
15
],
16
16
})
17
17
18
- with tempfile .NamedTemporaryFile (delete = False , suffix = ".csv" ) as tmp :
18
+ with tempfile .NamedTemporaryFile (suffix = ".csv" ) as tmp :
19
19
path = tmp .name
20
-
21
- try :
22
20
df .to_csv (path , index = False , preserve_complex = True )
23
21
df_loaded = pd .read_csv (path , preserve_complex = True )
24
- assert isinstance ( df_loaded [ "embedding" ][ 0 ], np . ndarray ), (
25
- "Test Failed: The CSV did not preserve embeddings as NumPy arrays!"
26
- )
27
- print ( "PASS: test_preserve_numpy_arrays_in_csv" )
28
- finally :
29
- os . remove ( path )
22
+
23
+ assert isinstance (
24
+ df_loaded [ "embedding" ][ 0 ], np . ndarray
25
+ ), "Test Failed: The CSV did not preserve embeddings as NumPy arrays!"
26
+
27
+ print ( "PASS: test_preserve_numpy_arrays_in_csv" )
30
28
31
29
32
30
def test_preserve_numpy_arrays_in_csv_empty_dataframe():
    """Check that an empty DataFrame is written as a header-only CSV.

    Builds a DataFrame with a single, empty ``embedding`` column, writes it
    with ``to_csv(..., index=False, preserve_complex=True)``, reads the raw
    file text back and compares it with the expected header line.

    Raises:
        AssertionError: if the text written to disk differs from the
            expected header line.
    """
    # Function-scope import so this test does not depend on a module-level
    # ``import os`` being present.
    import os

    print("\n Running: test_preserve_numpy_arrays_in_csv_empty_dataframe")
    df = pd.DataFrame({"embedding": []})
    # Header line only: the column name followed by a single newline, with
    # no trailing padding.
    expected = "embedding\n"

    # Write into a temporary directory instead of reopening a still-open
    # NamedTemporaryFile by name: reopening by name is not portable (it
    # fails on Windows), and a default NamedTemporaryFile is deleted on
    # close, so anything written to its path afterwards would be left
    # behind. The directory and its contents are removed on exit.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "empty.csv")
        # NOTE(review): ``preserve_complex`` is presumably provided by the
        # pandas build under test -- confirm.
        df.to_csv(path, index=False, preserve_complex=True)
        with open(path, encoding="utf-8") as f:
            result = f.read()

    msg = (
        f"CSV output mismatch for empty DataFrame.\n "
        f"Got:\n {result} \n Expected:\n {expected} "
    )
    assert result == expected, msg
    print("PASS: test_preserve_numpy_arrays_in_csv_empty_dataframe")
50
47
51
48
52
49
def test_preserve_numpy_arrays_in_csv_mixed_dtypes ():
@@ -56,30 +53,33 @@ def test_preserve_numpy_arrays_in_csv_mixed_dtypes():
56
53
"name" : ["alice" , "bob" ],
57
54
"scores" : [
58
55
np .array ([95.5 , 88.0 ]),
59
- np .array ([76.0 , 90.5 ])
56
+ np .array ([76.0 , 90.5 ]),
60
57
],
61
58
"age" : [25 , 30 ],
62
59
})
63
60
64
- with tempfile .NamedTemporaryFile (delete = False , suffix = ".csv" ) as tmp :
61
+ with tempfile .NamedTemporaryFile (suffix = ".csv" ) as tmp :
65
62
path = tmp .name
66
-
67
- try :
68
63
df .to_csv (path , index = False , preserve_complex = True )
69
64
df_loaded = pd .read_csv (path , preserve_complex = True )
70
- assert isinstance (df_loaded ["scores" ][0 ], np .ndarray ), (
71
- "Failed: 'scores' column not deserialized as np.ndarray."
65
+
66
+ err_scores = "Failed: 'scores' column not deserialized as np.ndarray."
67
+ assert isinstance (df_loaded ["scores" ][0 ], np .ndarray ), err_scores
68
+ assert df_loaded ["id" ].dtype == np .int64 , (
69
+ "Failed: 'id' should still be int."
70
+ )
71
+ assert df_loaded ["name" ].dtype == object , (
72
+ "Failed: 'name' should still be object/string."
73
+ )
74
+ assert df_loaded ["age" ].dtype == np .int64 , (
75
+ "Failed: 'age' should still be int."
72
76
)
73
- assert df_loaded ["id" ].dtype == np .int64 , "Failed: 'id' should still be int."
74
- assert df_loaded ["name" ].dtype == object , "Failed: 'name' should still be object/string."
75
- assert df_loaded ["age" ].dtype == np .int64 , "Failed: 'age' should still be int."
76
77
77
- print ("PASS: test_preserve_numpy_arrays_in_csv_mixed_dtypes" )
78
- finally :
79
- os .remove (path )
78
+ print ("PASS: test_preserve_numpy_arrays_in_csv_mixed_dtypes" )
80
79
81
80
82
81
if __name__ == "__main__":
    # Script entry point: run every test in declaration order, then print
    # a closing marker once all of them have returned.
    for run_test in (
        test_preserve_numpy_arrays_in_csv,
        test_preserve_numpy_arrays_in_csv_empty_dataframe,
        test_preserve_numpy_arrays_in_csv_mixed_dtypes,
    ):
        run_test()
    print("\n Done.")
0 commit comments