Skip to content

Commit d872245

Browse files
authored
Remove trailing slash from table location when creating a table (#702)
1 parent 74caa17 commit d872245

File tree

8 files changed

+270
-1
lines changed

8 files changed

+270
-1
lines changed

pyiceberg/catalog/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,7 @@ def _get_updated_props_and_update_summary(
779779
def _resolve_table_location(self, location: Optional[str], database_name: str, table_name: str) -> str:
780780
if not location:
781781
return self._get_default_warehouse_location(database_name, table_name)
782-
return location
782+
return location.rstrip("/")
783783

784784
def _get_default_warehouse_location(self, database_name: str, table_name: str) -> str:
785785
database_properties = self.load_namespace_properties(database_name)

pyiceberg/catalog/rest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,8 @@ def _create_table(
519519
fresh_sort_order = assign_fresh_sort_order_ids(sort_order, iceberg_schema, fresh_schema)
520520

521521
namespace_and_table = self._split_identifier_for_path(identifier)
522+
if location:
523+
location = location.rstrip("/")
522524
request = CreateTableRequest(
523525
name=namespace_and_table["table"],
524526
location=location,

tests/catalog/test_base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def create_table(
105105

106106
if not location:
107107
location = f'{self._warehouse_location}/{"/".join(identifier)}'
108+
location = location.rstrip("/")
108109

109110
metadata_location = self._get_metadata_location(location=location)
110111
metadata = new_table_metadata(
@@ -353,6 +354,19 @@ def test_create_table_location_override(catalog: InMemoryCatalog) -> None:
353354
assert table.location() == new_location
354355

355356

357+
def test_create_table_removes_trailing_slash_from_location(catalog: InMemoryCatalog) -> None:
358+
new_location = f"{catalog._warehouse_location}/new_location"
359+
table = catalog.create_table(
360+
identifier=TEST_TABLE_IDENTIFIER,
361+
schema=TEST_TABLE_SCHEMA,
362+
location=f"{new_location}/",
363+
partition_spec=TEST_TABLE_PARTITION_SPEC,
364+
properties=TEST_TABLE_PROPERTIES,
365+
)
366+
assert catalog.load_table(TEST_TABLE_IDENTIFIER) == table
367+
assert table.location() == new_location
368+
369+
356370
@pytest.mark.parametrize(
357371
"schema,expected",
358372
[

tests/catalog/test_dynamodb.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,21 @@ def test_create_table_with_given_location(
117117
assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location)
118118

119119

120+
@mock_aws
121+
def test_create_table_removes_trailing_slash_in_location(
122+
_bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str
123+
) -> None:
124+
catalog_name = "test_ddb_catalog"
125+
identifier = (database_name, table_name)
126+
test_catalog = DynamoDbCatalog(catalog_name, **{"s3.endpoint": moto_endpoint_url})
127+
test_catalog.create_namespace(namespace=database_name)
128+
location = f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"
129+
table = test_catalog.create_table(identifier=identifier, schema=table_schema_nested, location=f"{location}/")
130+
assert table.identifier == (catalog_name,) + identifier
131+
assert table.location() == location
132+
assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location)
133+
134+
120135
@mock_aws
121136
def test_create_table_with_no_location(
122137
_bucket_initialize: None, table_schema_nested: Schema, database_name: str, table_name: str

tests/catalog/test_glue.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,22 @@ def test_create_table_with_given_location(
137137
assert test_catalog._parse_metadata_version(table.metadata_location) == 0
138138

139139

140+
@mock_aws
141+
def test_create_table_removes_trailing_slash_in_location(
142+
_bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str
143+
) -> None:
144+
catalog_name = "glue"
145+
identifier = (database_name, table_name)
146+
test_catalog = GlueCatalog(catalog_name, **{"s3.endpoint": moto_endpoint_url})
147+
test_catalog.create_namespace(namespace=database_name)
148+
location = f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"
149+
table = test_catalog.create_table(identifier=identifier, schema=table_schema_nested, location=f"{location}/")
150+
assert table.identifier == (catalog_name,) + identifier
151+
assert table.location() == location
152+
assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location)
153+
assert test_catalog._parse_metadata_version(table.metadata_location) == 0
154+
155+
140156
@mock_aws
141157
def test_create_table_with_pyarrow_schema(
142158
_bucket_initialize: None,

tests/catalog/test_hive.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,181 @@ def test_create_table(
365365
assert metadata.model_dump() == expected.model_dump()
366366

367367

368+
@pytest.mark.parametrize("hive2_compatible", [True, False])
369+
@patch("time.time", MagicMock(return_value=12345))
370+
def test_create_table_with_given_location_removes_trailing_slash(
371+
table_schema_with_all_types: Schema, hive_database: HiveDatabase, hive_table: HiveTable, hive2_compatible: bool
372+
) -> None:
373+
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
374+
if hive2_compatible:
375+
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL, **{"hive.hive2-compatible": "true"})
376+
377+
location = f"{hive_database.locationUri}/table-given-location"
378+
379+
catalog._client = MagicMock()
380+
catalog._client.__enter__().create_table.return_value = None
381+
catalog._client.__enter__().get_table.return_value = hive_table
382+
catalog._client.__enter__().get_database.return_value = hive_database
383+
catalog.create_table(
384+
("default", "table"), schema=table_schema_with_all_types, properties={"owner": "javaberg"}, location=f"{location}/"
385+
)
386+
387+
called_hive_table: HiveTable = catalog._client.__enter__().create_table.call_args[0][0]
388+
# This one is generated within the function itself, so we need to extract
389+
# it to construct the assert_called_with
390+
metadata_location: str = called_hive_table.parameters["metadata_location"]
391+
assert metadata_location.endswith(".metadata.json")
392+
assert "/database/table-given-location/metadata/" in metadata_location
393+
catalog._client.__enter__().create_table.assert_called_with(
394+
HiveTable(
395+
tableName="table",
396+
dbName="default",
397+
owner="javaberg",
398+
createTime=12345,
399+
lastAccessTime=12345,
400+
retention=None,
401+
sd=StorageDescriptor(
402+
cols=[
403+
FieldSchema(name='boolean', type='boolean', comment=None),
404+
FieldSchema(name='integer', type='int', comment=None),
405+
FieldSchema(name='long', type='bigint', comment=None),
406+
FieldSchema(name='float', type='float', comment=None),
407+
FieldSchema(name='double', type='double', comment=None),
408+
FieldSchema(name='decimal', type='decimal(32,3)', comment=None),
409+
FieldSchema(name='date', type='date', comment=None),
410+
FieldSchema(name='time', type='string', comment=None),
411+
FieldSchema(name='timestamp', type='timestamp', comment=None),
412+
FieldSchema(
413+
name='timestamptz',
414+
type='timestamp' if hive2_compatible else 'timestamp with local time zone',
415+
comment=None,
416+
),
417+
FieldSchema(name='string', type='string', comment=None),
418+
FieldSchema(name='uuid', type='string', comment=None),
419+
FieldSchema(name='fixed', type='binary', comment=None),
420+
FieldSchema(name='binary', type='binary', comment=None),
421+
FieldSchema(name='list', type='array<string>', comment=None),
422+
FieldSchema(name='map', type='map<string,int>', comment=None),
423+
FieldSchema(name='struct', type='struct<inner_string:string,inner_int:int>', comment=None),
424+
],
425+
location=f"{hive_database.locationUri}/table-given-location",
426+
inputFormat="org.apache.hadoop.mapred.FileInputFormat",
427+
outputFormat="org.apache.hadoop.mapred.FileOutputFormat",
428+
compressed=None,
429+
numBuckets=None,
430+
serdeInfo=SerDeInfo(
431+
name=None,
432+
serializationLib="org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe",
433+
parameters=None,
434+
description=None,
435+
serializerClass=None,
436+
deserializerClass=None,
437+
serdeType=None,
438+
),
439+
bucketCols=None,
440+
sortCols=None,
441+
parameters=None,
442+
skewedInfo=None,
443+
storedAsSubDirectories=None,
444+
),
445+
partitionKeys=None,
446+
parameters={"EXTERNAL": "TRUE", "table_type": "ICEBERG", "metadata_location": metadata_location},
447+
viewOriginalText=None,
448+
viewExpandedText=None,
449+
tableType="EXTERNAL_TABLE",
450+
privileges=None,
451+
temporary=False,
452+
rewriteEnabled=None,
453+
creationMetadata=None,
454+
catName=None,
455+
ownerType=1,
456+
writeId=-1,
457+
isStatsCompliant=None,
458+
colStats=None,
459+
accessType=None,
460+
requiredReadCapabilities=None,
461+
requiredWriteCapabilities=None,
462+
id=None,
463+
fileMetadata=None,
464+
dictionary=None,
465+
txnId=None,
466+
)
467+
)
468+
469+
with open(metadata_location, encoding=UTF8) as f:
470+
payload = f.read()
471+
472+
metadata = TableMetadataUtil.parse_raw(payload)
473+
474+
assert "database/table-given-location" in metadata.location
475+
476+
expected = TableMetadataV2(
477+
location=metadata.location,
478+
table_uuid=metadata.table_uuid,
479+
last_updated_ms=metadata.last_updated_ms,
480+
last_column_id=22,
481+
schemas=[
482+
Schema(
483+
NestedField(field_id=1, name='boolean', field_type=BooleanType(), required=True),
484+
NestedField(field_id=2, name='integer', field_type=IntegerType(), required=True),
485+
NestedField(field_id=3, name='long', field_type=LongType(), required=True),
486+
NestedField(field_id=4, name='float', field_type=FloatType(), required=True),
487+
NestedField(field_id=5, name='double', field_type=DoubleType(), required=True),
488+
NestedField(field_id=6, name='decimal', field_type=DecimalType(precision=32, scale=3), required=True),
489+
NestedField(field_id=7, name='date', field_type=DateType(), required=True),
490+
NestedField(field_id=8, name='time', field_type=TimeType(), required=True),
491+
NestedField(field_id=9, name='timestamp', field_type=TimestampType(), required=True),
492+
NestedField(field_id=10, name='timestamptz', field_type=TimestamptzType(), required=True),
493+
NestedField(field_id=11, name='string', field_type=StringType(), required=True),
494+
NestedField(field_id=12, name='uuid', field_type=UUIDType(), required=True),
495+
NestedField(field_id=13, name='fixed', field_type=FixedType(length=12), required=True),
496+
NestedField(field_id=14, name='binary', field_type=BinaryType(), required=True),
497+
NestedField(
498+
field_id=15,
499+
name='list',
500+
field_type=ListType(type='list', element_id=18, element_type=StringType(), element_required=True),
501+
required=True,
502+
),
503+
NestedField(
504+
field_id=16,
505+
name='map',
506+
field_type=MapType(
507+
type='map', key_id=19, key_type=StringType(), value_id=20, value_type=IntegerType(), value_required=True
508+
),
509+
required=True,
510+
),
511+
NestedField(
512+
field_id=17,
513+
name='struct',
514+
field_type=StructType(
515+
NestedField(field_id=21, name='inner_string', field_type=StringType(), required=False),
516+
NestedField(field_id=22, name='inner_int', field_type=IntegerType(), required=True),
517+
),
518+
required=False,
519+
),
520+
schema_id=0,
521+
identifier_field_ids=[2],
522+
)
523+
],
524+
current_schema_id=0,
525+
last_partition_id=999,
526+
properties={"owner": "javaberg", 'write.parquet.compression-codec': 'zstd'},
527+
partition_specs=[PartitionSpec()],
528+
default_spec_id=0,
529+
current_snapshot_id=None,
530+
snapshots=[],
531+
snapshot_log=[],
532+
metadata_log=[],
533+
sort_orders=[SortOrder(order_id=0)],
534+
default_sort_order_id=0,
535+
refs={},
536+
format_version=2,
537+
last_sequence_number=0,
538+
)
539+
540+
assert metadata.model_dump() == expected.model_dump()
541+
542+
368543
@patch("time.time", MagicMock(return_value=12345))
369544
def test_create_v1_table(table_schema_simple: Schema, hive_database: HiveDatabase, hive_table: HiveTable) -> None:
370545
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)

tests/catalog/test_rest.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,31 @@ def test_create_table_200(
732732
assert actual == expected
733733

734734

735+
def test_create_table_with_given_location_removes_trailing_slash_200(
736+
rest_mock: Mocker, table_schema_simple: Schema, example_table_metadata_no_snapshot_v1_rest_json: Dict[str, Any]
737+
) -> None:
738+
rest_mock.post(
739+
f"{TEST_URI}v1/namespaces/fokko/tables",
740+
json=example_table_metadata_no_snapshot_v1_rest_json,
741+
status_code=200,
742+
request_headers=TEST_HEADERS,
743+
)
744+
catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN)
745+
location = "s3://warehouse/database/table-custom-location"
746+
catalog.create_table(
747+
identifier=("fokko", "fokko2"),
748+
schema=table_schema_simple,
749+
location=f"{location}/",
750+
partition_spec=PartitionSpec(
751+
PartitionField(source_id=1, field_id=1000, transform=TruncateTransform(width=3), name="id"), spec_id=1
752+
),
753+
sort_order=SortOrder(SortField(source_id=2, transform=IdentityTransform())),
754+
properties={"owner": "fokko"},
755+
)
756+
assert rest_mock.last_request
757+
assert rest_mock.last_request.json()["location"] == location
758+
759+
735760
def test_create_table_409(rest_mock: Mocker, table_schema_simple: Schema) -> None:
736761
rest_mock.post(
737762
f"{TEST_URI}v1/namespaces/fokko/tables",

tests/catalog/test_sql.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,28 @@ def test_create_table_with_default_warehouse_location(
264264
catalog.drop_table(random_identifier)
265265

266266

267+
@pytest.mark.parametrize(
268+
'catalog',
269+
[
270+
lazy_fixture('catalog_memory'),
271+
lazy_fixture('catalog_sqlite'),
272+
],
273+
)
274+
def test_create_table_with_given_location_removes_trailing_slash(
275+
warehouse: Path, catalog: SqlCatalog, table_schema_nested: Schema, random_identifier: Identifier
276+
) -> None:
277+
database_name, table_name = random_identifier
278+
location = f"file://{warehouse}/{database_name}.db/{table_name}-given"
279+
catalog.create_namespace(database_name)
280+
catalog.create_table(random_identifier, table_schema_nested, location=f"{location}/")
281+
table = catalog.load_table(random_identifier)
282+
assert table.identifier == (catalog.name,) + random_identifier
283+
assert table.metadata_location.startswith(f"file://{warehouse}")
284+
assert os.path.exists(table.metadata_location[len("file://") :])
285+
assert table.location() == location
286+
catalog.drop_table(random_identifier)
287+
288+
267289
@pytest.mark.parametrize(
268290
'catalog',
269291
[

0 commit comments

Comments
 (0)