Skip to content

Commit 0a2e9ea

Browse files
Using fixture for catalog
1 parent c25536c commit 0a2e9ea

File tree

1 file changed

+69
-72
lines changed

1 file changed

+69
-72
lines changed

pandas/tests/io/test_iceberg.py

Lines changed: 69 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55
data used for Parquet tests (``pandas/tests/io/data/parquet/simple.parquet``).
66
"""
77

8-
from contextlib import contextmanager
8+
import collections
99
import importlib
1010
import pathlib
11-
import tempfile
1211

1312
import pytest
1413

@@ -24,73 +23,75 @@
2423
pq = pytest.importorskip("pyarrow.parquet")
2524

2625

27-
@contextmanager
28-
def create_catalog(catalog_name_in_pyiceberg_config=None):
29-
# the catalog stores the full path of data files, so the catalog needs to be
30-
# created dynamically, and not saved in pandas/tests/io/data as other formats
31-
with tempfile.TemporaryDirectory("-pandas-iceberg.tmp") as catalog_path:
32-
uri = f"sqlite:///{catalog_path}/catalog.sqlite"
33-
warehouse = f"file://{catalog_path}"
34-
catalog = pyiceberg_catalog.load_catalog(
35-
catalog_name_in_pyiceberg_config or "default",
36-
type="sql",
37-
uri=uri,
38-
warehouse=warehouse,
39-
)
40-
catalog.create_namespace("ns")
26+
Catalog = collections.namedtuple("name", "uri")
4127

42-
df = pq.read_table(
43-
pathlib.Path(__file__).parent / "data" / "parquet" / "simple.parquet"
44-
)
45-
table = catalog.create_table("ns.my_table", schema=df.schema)
46-
table.append(df)
4728

48-
if catalog_name_in_pyiceberg_config is not None:
49-
config_path = pathlib.Path.home() / ".pyiceberg.yaml"
50-
with open(config_path, "w", encoding="utf-8") as f:
51-
f.write(f"""\
29+
@pytest.fixture
30+
def catalog(request, tmp_path, params=(None, "default", "pandas_tests")):
31+
# the catalog stores the full path of data files, so the catalog needs to be
32+
# created dynamically, and not saved in pandas/tests/io/data as other formats
33+
catalog_path = tmp_path / "pandas-iceberg-catalog"
34+
catalog_path.mkdir()
35+
catalog_name = request.param
36+
uri = f"sqlite:///{catalog_path}/catalog.sqlite"
37+
warehouse = f"file://{catalog_path}"
38+
catalog = pyiceberg_catalog.load_catalog(
39+
catalog_name,
40+
type="sql",
41+
uri=uri,
42+
warehouse=warehouse,
43+
)
44+
catalog.create_namespace("ns")
45+
46+
df = pq.read_table(
47+
pathlib.Path(__file__).parent / "data" / "parquet" / "simple.parquet"
48+
)
49+
table = catalog.create_table("ns.my_table", schema=df.schema)
50+
table.append(df)
51+
52+
if catalog_name is not None:
53+
config_path = pathlib.Path.home() / ".pyiceberg.yaml"
54+
with open(config_path, "w", encoding="utf-8") as f:
55+
f.write(f"""\
5256
catalog:
53-
{catalog_name_in_pyiceberg_config}:
57+
{catalog_name}:
5458
type: sql
5559
uri: {uri}
5660
warehouse: {warehouse}""")
57-
importlib.reload(pyiceberg_catalog) # needed to reload the config file
5861

59-
try:
60-
yield uri
61-
finally:
62-
if catalog_name_in_pyiceberg_config is not None:
63-
config_path.unlink()
62+
importlib.reload(pyiceberg_catalog) # needed to reload the config file
63+
64+
yield Catalog(name=catalog_name, uri=uri)
65+
66+
if catalog_name is not None:
67+
config_path.unlink()
6468

6569

6670
class TestIceberg:
67-
def test_read(self):
71+
def test_read(self, catalog):
6872
expected = pd.DataFrame(
6973
{
7074
"A": [1, 2, 3],
7175
"B": ["foo", "foo", "foo"],
7276
}
7377
)
74-
with create_catalog() as catalog_uri:
75-
result = read_iceberg(
76-
"ns.my_table",
77-
catalog_properties={"uri": catalog_uri},
78-
)
78+
result = read_iceberg(
79+
"ns.my_table",
80+
catalog_properties={"uri": catalog.uri},
81+
)
7982
tm.assert_frame_equal(result, expected)
8083

81-
@pytest.mark.parametrize("catalog_name", ["default", "pandas_tests"])
82-
def test_read_by_catalog_name(self, catalog_name):
84+
def test_read_by_catalog_name(self, catalog):
8385
expected = pd.DataFrame(
8486
{
8587
"A": [1, 2, 3],
8688
"B": ["foo", "foo", "foo"],
8789
}
8890
)
89-
with create_catalog(catalog_name_in_pyiceberg_config=catalog_name):
90-
result = read_iceberg(
91-
"ns.my_table",
92-
catalog_name=catalog_name,
93-
)
91+
result = read_iceberg(
92+
"ns.my_table",
93+
catalog_name=catalog.name,
94+
)
9495
tm.assert_frame_equal(result, expected)
9596

9697
def test_read_with_row_filter(self):
@@ -100,37 +101,34 @@ def test_read_with_row_filter(self):
100101
"B": ["foo", "foo"],
101102
}
102103
)
103-
with create_catalog() as catalog_uri:
104-
result = read_iceberg(
105-
"ns.my_table",
106-
catalog_properties={"uri": catalog_uri},
107-
row_filter="A > 1",
108-
)
104+
result = read_iceberg(
105+
"ns.my_table",
106+
catalog_properties={"uri": catalog.uri},
107+
row_filter="A > 1",
108+
)
109109
tm.assert_frame_equal(result, expected)
110110

111-
def test_read_with_case_sensitive(self):
111+
def test_read_with_case_sensitive(self, catalog):
112112
expected = pd.DataFrame(
113113
{
114114
"A": [1, 2, 3],
115115
}
116116
)
117-
with create_catalog() as catalog_uri:
118-
result = read_iceberg(
117+
result = read_iceberg(
118+
"ns.my_table",
119+
catalog_properties={"uri": catalog.uri},
120+
selected_fields=["a"],
121+
case_sensitive=False,
122+
)
123+
tm.assert_frame_equal(result, expected)
124+
125+
with pytest.raises(ValueError, match="^Could not find column"):
126+
read_iceberg(
119127
"ns.my_table",
120-
catalog_properties={"uri": catalog_uri},
128+
catalog_properties={"uri": catalog.uri},
121129
selected_fields=["a"],
122-
case_sensitive=False,
130+
case_sensitive=True,
123131
)
124-
tm.assert_frame_equal(result, expected)
125-
126-
with create_catalog() as catalog_uri:
127-
with pytest.raises(ValueError, match="^Could not find column"):
128-
read_iceberg(
129-
"ns.my_table",
130-
catalog_properties={"uri": catalog_uri},
131-
selected_fields=["a"],
132-
case_sensitive=True,
133-
)
134132

135133
def test_read_with_limit(self):
136134
expected = pd.DataFrame(
@@ -139,10 +137,9 @@ def test_read_with_limit(self):
139137
"B": ["foo", "foo"],
140138
}
141139
)
142-
with create_catalog() as catalog_uri:
143-
result = read_iceberg(
144-
"ns.my_table",
145-
catalog_properties={"uri": catalog_uri},
146-
limit=2,
147-
)
140+
result = read_iceberg(
141+
"ns.my_table",
142+
catalog_properties={"uri": catalog.uri},
143+
limit=2,
144+
)
148145
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)