Commit 8014b6c

Upsert: Reuse existing expression to detect rows to be inserted (#1662)
Also a slight refactor of the tests to bring them more in line with the rest of the test suite
1 parent ee11bb0 commit 8014b6c
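
For orientation, a minimal usage sketch of the updated upsert API. This is hedged: `tbl` is assumed to be an already-loaded pyiceberg Table, the sample data is made up, and the comments reflect the signature shown in the diff below.

import pyarrow as pa

# Hypothetical source data; `tbl` is assumed to be a pyiceberg Table loaded from a catalog.
df = pa.table({"order_id": [1, 2], "order_type": ["A", "B"]})

res = tbl.upsert(
    df=df,
    join_cols=["order_id"],            # treated like a primary key
    when_matched_update_all=True,      # update matched rows whose non-key values changed
    when_not_matched_insert_all=True,  # insert source rows not present in the table
    case_sensitive=True,               # new flag, passed through to the scan and expression binding
)
print(res.rows_updated, res.rows_inserted)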

3 files changed: +66 additions, -74 deletions


pyiceberg/table/__init__.py

Lines changed: 13 additions & 4 deletions
@@ -62,7 +62,7 @@
     manifest_evaluator,
 )
 from pyiceberg.io import FileIO, load_file_io
-from pyiceberg.io.pyarrow import ArrowScan, schema_to_pyarrow
+from pyiceberg.io.pyarrow import ArrowScan, expression_to_pyarrow, schema_to_pyarrow
 from pyiceberg.manifest import (
     POSITIONAL_DELETE_SCHEMA,
     DataFile,
@@ -1101,7 +1101,12 @@ def name_mapping(self) -> Optional[NameMapping]:
         return self.metadata.name_mapping()

     def upsert(
-        self, df: pa.Table, join_cols: list[str], when_matched_update_all: bool = True, when_not_matched_insert_all: bool = True
+        self,
+        df: pa.Table,
+        join_cols: list[str],
+        when_matched_update_all: bool = True,
+        when_not_matched_insert_all: bool = True,
+        case_sensitive: bool = True,
     ) -> UpsertResult:
         """Shorthand API for performing an upsert to an iceberg table.

@@ -1111,6 +1116,7 @@ def upsert(
             join_cols: The columns to join on. These are essentially analogous to primary keys
             when_matched_update_all: Bool indicating to update rows that are matched but require an update due to a value in a non-key column changing
             when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table
+            case_sensitive: Bool indicating if the match should be case-sensitive

         Example Use Cases:
             Case 1: Both Parameters = True (Full Upsert)
@@ -1144,7 +1150,7 @@ def upsert(

         # get list of rows that exist so we don't have to load the entire target table
         matched_predicate = upsert_util.create_match_filter(df, join_cols)
-        matched_iceberg_table = self.scan(row_filter=matched_predicate).to_arrow()
+        matched_iceberg_table = self.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow()

         update_row_cnt = 0
         insert_row_cnt = 0
@@ -1164,7 +1170,10 @@ def upsert(
                     tx.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate)

             if when_not_matched_insert_all:
-                rows_to_insert = upsert_util.get_rows_to_insert(df, matched_iceberg_table, join_cols)
+                expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols)
+                expr_match_bound = bind(self.schema(), expr_match, case_sensitive=case_sensitive)
+                expr_match_arrow = expression_to_pyarrow(expr_match_bound)
+                rows_to_insert = df.filter(~expr_match_arrow)

                 insert_row_cnt = len(rows_to_insert)

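In other words, the insert path now reuses the same expression machinery as the rest of the upsert: build a match filter over the rows that were found in the table, bind it to the table schema, convert it to a PyArrow compute expression, and keep the source rows it does not match. A rough standalone sketch of that flow, assuming `table` is a pyiceberg Table and `df` is the PyArrow source table; the helper names mirror the diff above, and `create_match_filter` is assumed to return an unbound Iceberg boolean expression over the join columns.

# Sketch of the flow added above (not a drop-in snippet).
from pyiceberg.expressions.visitors import bind
from pyiceberg.io.pyarrow import expression_to_pyarrow
from pyiceberg.table import upsert_util

join_cols = ["order_id"]

# 1. Rows of `df` that already exist in the table, fetched once via a filtered scan.
matched_predicate = upsert_util.create_match_filter(df, join_cols)
matched_iceberg_table = table.scan(row_filter=matched_predicate, case_sensitive=True).to_arrow()

# 2. Build a match filter over the rows that were found, bind it to the table
#    schema, and translate it into a PyArrow compute expression.
expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols)
expr_match_bound = bind(table.schema(), expr_match, case_sensitive=True)
expr_match_arrow = expression_to_pyarrow(expr_match_bound)

# 3. Everything in the source that the expression does NOT match is new and gets inserted.
rows_to_insert = df.filter(~expr_match_arrow)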
pyiceberg/table/upsert_util.py

Lines changed: 0 additions & 24 deletions
@@ -92,27 +92,3 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
     rows_to_update_table = rows_to_update_table.select(list(common_columns))

     return rows_to_update_table
-
-
-def get_rows_to_insert(source_table: pa.Table, target_table: pa.Table, join_cols: list[str]) -> pa.Table:
-    source_filter_expr = pc.scalar(True)
-
-    for col in join_cols:
-        target_values = target_table.column(col).to_pylist()
-        expr = pc.field(col).isin(target_values)
-
-        if source_filter_expr is None:
-            source_filter_expr = expr
-        else:
-            source_filter_expr = source_filter_expr & expr
-
-    non_matching_expr = ~source_filter_expr
-
-    source_columns = set(source_table.column_names)
-    target_columns = set(target_table.column_names)
-
-    common_columns = source_columns.intersection(target_columns)
-
-    non_matching_rows = source_table.filter(non_matching_expr).select(common_columns)
-
-    return non_matching_rows
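
One behavioural aside on the removed helper, hedged since the commit message only cites reuse as the motivation: because it combined per-column isin filters, composite keys were matched column by column rather than as row tuples. A small self-contained illustration with made-up data, using only PyArrow:

# Illustration of the per-column isin matching used by the removed helper
# (hypothetical data; not taken from this commit's tests).
import pyarrow as pa
import pyarrow.compute as pc

source = pa.table({"order_id": [1, 1], "order_line_id": [2, 4]})
target = pa.table({"order_id": [1, 3], "order_line_id": [2, 4]})

expr = pc.scalar(True)
for col in ["order_id", "order_line_id"]:
    expr = expr & pc.field(col).isin(target.column(col).to_pylist())

# Both source rows satisfy the per-column filters (1 is in {1, 3}, and both 2 and 4
# are in {2, 4}), so neither would be inserted -- even though the pair (1, 4)
# does not exist in the target as a row.
print(source.filter(~expr).num_rows)  # 0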

tests/table/test_upsert.py

Lines changed: 53 additions & 46 deletions
@@ -14,14 +14,30 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from pathlib import PosixPath
+
 import pytest
 from datafusion import SessionContext
 from pyarrow import Table as pa_table

+from pyiceberg.catalog import Catalog
+from pyiceberg.exceptions import NoSuchTableError
 from pyiceberg.table import UpsertResult
 from tests.catalog.test_base import InMemoryCatalog, Table

-_TEST_NAMESPACE = "test_ns"
+
+@pytest.fixture
+def catalog(tmp_path: PosixPath) -> InMemoryCatalog:
+    catalog = InMemoryCatalog("test.in_memory.catalog", warehouse=tmp_path.absolute().as_posix())
+    catalog.create_namespace("default")
+    return catalog
+
+
+def _drop_table(catalog: Catalog, identifier: str) -> None:
+    try:
+        catalog.drop_table(identifier)
+    except NoSuchTableError:
+        pass


 def show_iceberg_table(table: Table, ctx: SessionContext) -> None:
@@ -72,7 +88,7 @@ def gen_source_dataset(start_row: int, end_row: int, composite_key: bool, add_du


 def gen_target_iceberg_table(
-    start_row: int, end_row: int, composite_key: bool, ctx: SessionContext, catalog: InMemoryCatalog, namespace: str
+    start_row: int, end_row: int, composite_key: bool, ctx: SessionContext, catalog: InMemoryCatalog, identifier: str
 ) -> Table:
     additional_columns = ", t.order_id + 1000 as order_line_id" if composite_key else ""

@@ -83,7 +99,7 @@ def gen_target_iceberg_table(
         from t
     """).to_arrow_table()

-    table = catalog.create_table(f"{_TEST_NAMESPACE}.target", df.schema)
+    table = catalog.create_table(identifier, df.schema)

     table.append(df)

@@ -95,13 +111,6 @@ def assert_upsert_result(res: UpsertResult, expected_updated: int, expected_inse
     assert res.rows_inserted == expected_inserted, f"rows inserted should be {expected_inserted}, but got {res.rows_inserted}"


-@pytest.fixture(scope="session")
-def catalog_conn() -> InMemoryCatalog:
-    catalog = InMemoryCatalog("test")
-    catalog.create_namespace(namespace=_TEST_NAMESPACE)
-    yield catalog
-
-
 @pytest.mark.parametrize(
     "join_cols, src_start_row, src_end_row, target_start_row, target_end_row, when_matched_update_all, when_not_matched_insert_all, expected_updated, expected_inserted",
     [
@@ -112,7 +121,7 @@ def catalog_conn() -> InMemoryCatalog:
     ],
 )
 def test_merge_rows(
-    catalog_conn: InMemoryCatalog,
+    catalog: Catalog,
     join_cols: list[str],
     src_start_row: int,
     src_end_row: int,
@@ -123,12 +132,13 @@ def test_merge_rows(
     expected_updated: int,
     expected_inserted: int,
 ) -> None:
-    ctx = SessionContext()
+    identifier = "default.test_merge_rows"
+    _drop_table(catalog, identifier)

-    catalog = catalog_conn
+    ctx = SessionContext()

     source_df = gen_source_dataset(src_start_row, src_end_row, False, False, ctx)
-    ice_table = gen_target_iceberg_table(target_start_row, target_end_row, False, ctx, catalog, _TEST_NAMESPACE)
+    ice_table = gen_target_iceberg_table(target_start_row, target_end_row, False, ctx, catalog, identifier)
     res = ice_table.upsert(
         df=source_df,
         join_cols=join_cols,
@@ -138,13 +148,13 @@ def test_merge_rows(

     assert_upsert_result(res, expected_updated, expected_inserted)

-    catalog.drop_table(f"{_TEST_NAMESPACE}.target")

-
-def test_merge_scenario_skip_upd_row(catalog_conn: InMemoryCatalog) -> None:
+def test_merge_scenario_skip_upd_row(catalog: Catalog) -> None:
     """
     tests a single insert and update; skips a row that does not need to be updated
     """
+    identifier = "default.test_merge_scenario_skip_upd_row"
+    _drop_table(catalog, identifier)

     ctx = SessionContext()

@@ -154,8 +164,7 @@ def test_merge_scenario_skip_upd_row(catalog_conn: InMemoryCatalog) -> None:
         select 2 as order_id, date '2021-01-01' as order_date, 'A' as order_type
     """).to_arrow_table()

-    catalog = catalog_conn
-    table = catalog.create_table(f"{_TEST_NAMESPACE}.target", df.schema)
+    table = catalog.create_table(identifier, df.schema)

     table.append(df)

@@ -174,24 +183,24 @@ def test_merge_scenario_skip_upd_row(catalog_conn: InMemoryCatalog) -> None:

     assert_upsert_result(res, expected_updated, expected_inserted)

-    catalog.drop_table(f"{_TEST_NAMESPACE}.target")
-

-def test_merge_scenario_date_as_key(catalog_conn: InMemoryCatalog) -> None:
+def test_merge_scenario_date_as_key(catalog: Catalog) -> None:
     """
     tests a single insert and update; primary key is a date column
     """

     ctx = SessionContext()

+    identifier = "default.test_merge_scenario_date_as_key"
+    _drop_table(catalog, identifier)
+
     df = ctx.sql("""
         select date '2021-01-01' as order_date, 'A' as order_type
         union all
         select date '2021-01-02' as order_date, 'A' as order_type
     """).to_arrow_table()

-    catalog = catalog_conn
-    table = catalog.create_table(f"{_TEST_NAMESPACE}.target", df.schema)
+    table = catalog.create_table(identifier, df.schema)

     table.append(df)

@@ -210,14 +219,15 @@ def test_merge_scenario_date_as_key(catalog_conn: InMemoryCatalog) -> None:

     assert_upsert_result(res, expected_updated, expected_inserted)

-    catalog.drop_table(f"{_TEST_NAMESPACE}.target")
-

-def test_merge_scenario_string_as_key(catalog_conn: InMemoryCatalog) -> None:
+def test_merge_scenario_string_as_key(catalog: Catalog) -> None:
     """
     tests a single insert and update; primary key is a string column
     """

+    identifier = "default.test_merge_scenario_string_as_key"
+    _drop_table(catalog, identifier)
+
     ctx = SessionContext()

     df = ctx.sql("""
@@ -226,8 +236,7 @@ def test_merge_scenario_string_as_key(catalog_conn: InMemoryCatalog) -> None:
         select 'def' as order_id, 'A' as order_type
     """).to_arrow_table()

-    catalog = catalog_conn
-    table = catalog.create_table(f"{_TEST_NAMESPACE}.target", df.schema)
+    table = catalog.create_table(identifier, df.schema)

     table.append(df)

@@ -246,18 +255,18 @@ def test_merge_scenario_string_as_key(catalog_conn: InMemoryCatalog) -> None:

     assert_upsert_result(res, expected_updated, expected_inserted)

-    catalog.drop_table(f"{_TEST_NAMESPACE}.target")

-
-def test_merge_scenario_composite_key(catalog_conn: InMemoryCatalog) -> None:
+def test_merge_scenario_composite_key(catalog: Catalog) -> None:
     """
     tests merging 200 rows with a composite key
     """

+    identifier = "default.test_merge_scenario_composite_key"
+    _drop_table(catalog, identifier)
+
     ctx = SessionContext()

-    catalog = catalog_conn
-    table = gen_target_iceberg_table(1, 200, True, ctx, catalog, _TEST_NAMESPACE)
+    table = gen_target_iceberg_table(1, 200, True, ctx, catalog, identifier)
     source_df = gen_source_dataset(101, 300, True, False, ctx)

     res = table.upsert(df=source_df, join_cols=["order_id", "order_line_id"])
@@ -267,43 +276,41 @@ def test_merge_scenario_composite_key(catalog_conn: InMemoryCatalog) -> None:

     assert_upsert_result(res, expected_updated, expected_inserted)

-    catalog.drop_table(f"{_TEST_NAMESPACE}.target")

-
-def test_merge_source_dups(catalog_conn: InMemoryCatalog) -> None:
+def test_merge_source_dups(catalog: Catalog) -> None:
     """
     tests duplicate rows in source
     """

+    identifier = "default.test_merge_source_dups"
+    _drop_table(catalog, identifier)
+
     ctx = SessionContext()

-    catalog = catalog_conn
-    table = gen_target_iceberg_table(1, 10, False, ctx, catalog, _TEST_NAMESPACE)
+    table = gen_target_iceberg_table(1, 10, False, ctx, catalog, identifier)
     source_df = gen_source_dataset(5, 15, False, True, ctx)

     with pytest.raises(Exception, match="Duplicate rows found in source dataset based on the key columns. No upsert executed"):
         table.upsert(df=source_df, join_cols=["order_id"])

-    catalog.drop_table(f"{_TEST_NAMESPACE}.target")
-

-def test_key_cols_misaligned(catalog_conn: InMemoryCatalog) -> None:
+def test_key_cols_misaligned(catalog: Catalog) -> None:
     """
     tests join columns missing from one of the tables
     """

+    identifier = "default.test_key_cols_misaligned"
+    _drop_table(catalog, identifier)
+
     ctx = SessionContext()

     df = ctx.sql("select 1 as order_id, date '2021-01-01' as order_date, 'A' as order_type").to_arrow_table()

-    catalog = catalog_conn
-    table = catalog.create_table(f"{_TEST_NAMESPACE}.target", df.schema)
+    table = catalog.create_table(identifier, df.schema)

     table.append(df)

     df_src = ctx.sql("select 1 as item_id, date '2021-05-01' as order_date, 'B' as order_type").to_arrow_table()

     with pytest.raises(Exception, match=r"""Field ".*" does not exist in schema"""):
         table.upsert(df=df_src, join_cols=["order_id"])
-
-    catalog.drop_table(f"{_TEST_NAMESPACE}.target")
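
Finally, a consolidated sketch of the new test pattern introduced above (a per-test table identifier plus an idempotent drop via the shared `catalog` fixture). The table name and sample data here are illustrative only, not part of the commit.

# Hypothetical test following the refactored pattern; assumed to live in
# tests/table/test_upsert.py next to the fixture and helpers shown above.
def test_example_upsert(catalog: Catalog) -> None:
    identifier = "default.test_example_upsert"
    _drop_table(catalog, identifier)  # idempotent cleanup from a previous run

    ctx = SessionContext()
    df = ctx.sql("select 1 as order_id, 'A' as order_type").to_arrow_table()

    table = catalog.create_table(identifier, df.schema)
    table.append(df)

    src = ctx.sql("""
        select 1 as order_id, 'B' as order_type
        union all
        select 2 as order_id, 'B' as order_type
    """).to_arrow_table()

    res = table.upsert(df=src, join_cols=["order_id"])

    # order_id 1 changes a non-key value (update); order_id 2 is new (insert).
    assert_upsert_result(res, expected_updated=1, expected_inserted=1)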
