@@ -365,6 +365,181 @@ def test_create_table(
365
365
assert metadata .model_dump () == expected .model_dump ()
366
366
367
367
368
+ @pytest .mark .parametrize ("hive2_compatible" , [True , False ])
369
+ @patch ("time.time" , MagicMock (return_value = 12345 ))
370
+ def test_create_table_with_given_location_removes_trailing_slash (
371
+ table_schema_with_all_types : Schema , hive_database : HiveDatabase , hive_table : HiveTable , hive2_compatible : bool
372
+ ) -> None :
373
+ catalog = HiveCatalog (HIVE_CATALOG_NAME , uri = HIVE_METASTORE_FAKE_URL )
374
+ if hive2_compatible :
375
+ catalog = HiveCatalog (HIVE_CATALOG_NAME , uri = HIVE_METASTORE_FAKE_URL , ** {"hive.hive2-compatible" : "true" })
376
+
377
+ location = f"{ hive_database .locationUri } /table-given-location"
378
+
379
+ catalog ._client = MagicMock ()
380
+ catalog ._client .__enter__ ().create_table .return_value = None
381
+ catalog ._client .__enter__ ().get_table .return_value = hive_table
382
+ catalog ._client .__enter__ ().get_database .return_value = hive_database
383
+ catalog .create_table (
384
+ ("default" , "table" ), schema = table_schema_with_all_types , properties = {"owner" : "javaberg" }, location = f"{ location } /"
385
+ )
386
+
387
+ called_hive_table : HiveTable = catalog ._client .__enter__ ().create_table .call_args [0 ][0 ]
388
+ # This one is generated within the function itself, so we need to extract
389
+ # it to construct the assert_called_with
390
+ metadata_location : str = called_hive_table .parameters ["metadata_location" ]
391
+ assert metadata_location .endswith (".metadata.json" )
392
+ assert "/database/table-given-location/metadata/" in metadata_location
393
+ catalog ._client .__enter__ ().create_table .assert_called_with (
394
+ HiveTable (
395
+ tableName = "table" ,
396
+ dbName = "default" ,
397
+ owner = "javaberg" ,
398
+ createTime = 12345 ,
399
+ lastAccessTime = 12345 ,
400
+ retention = None ,
401
+ sd = StorageDescriptor (
402
+ cols = [
403
+ FieldSchema (name = 'boolean' , type = 'boolean' , comment = None ),
404
+ FieldSchema (name = 'integer' , type = 'int' , comment = None ),
405
+ FieldSchema (name = 'long' , type = 'bigint' , comment = None ),
406
+ FieldSchema (name = 'float' , type = 'float' , comment = None ),
407
+ FieldSchema (name = 'double' , type = 'double' , comment = None ),
408
+ FieldSchema (name = 'decimal' , type = 'decimal(32,3)' , comment = None ),
409
+ FieldSchema (name = 'date' , type = 'date' , comment = None ),
410
+ FieldSchema (name = 'time' , type = 'string' , comment = None ),
411
+ FieldSchema (name = 'timestamp' , type = 'timestamp' , comment = None ),
412
+ FieldSchema (
413
+ name = 'timestamptz' ,
414
+ type = 'timestamp' if hive2_compatible else 'timestamp with local time zone' ,
415
+ comment = None ,
416
+ ),
417
+ FieldSchema (name = 'string' , type = 'string' , comment = None ),
418
+ FieldSchema (name = 'uuid' , type = 'string' , comment = None ),
419
+ FieldSchema (name = 'fixed' , type = 'binary' , comment = None ),
420
+ FieldSchema (name = 'binary' , type = 'binary' , comment = None ),
421
+ FieldSchema (name = 'list' , type = 'array<string>' , comment = None ),
422
+ FieldSchema (name = 'map' , type = 'map<string,int>' , comment = None ),
423
+ FieldSchema (name = 'struct' , type = 'struct<inner_string:string,inner_int:int>' , comment = None ),
424
+ ],
425
+ location = f"{ hive_database .locationUri } /table-given-location" ,
426
+ inputFormat = "org.apache.hadoop.mapred.FileInputFormat" ,
427
+ outputFormat = "org.apache.hadoop.mapred.FileOutputFormat" ,
428
+ compressed = None ,
429
+ numBuckets = None ,
430
+ serdeInfo = SerDeInfo (
431
+ name = None ,
432
+ serializationLib = "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe" ,
433
+ parameters = None ,
434
+ description = None ,
435
+ serializerClass = None ,
436
+ deserializerClass = None ,
437
+ serdeType = None ,
438
+ ),
439
+ bucketCols = None ,
440
+ sortCols = None ,
441
+ parameters = None ,
442
+ skewedInfo = None ,
443
+ storedAsSubDirectories = None ,
444
+ ),
445
+ partitionKeys = None ,
446
+ parameters = {"EXTERNAL" : "TRUE" , "table_type" : "ICEBERG" , "metadata_location" : metadata_location },
447
+ viewOriginalText = None ,
448
+ viewExpandedText = None ,
449
+ tableType = "EXTERNAL_TABLE" ,
450
+ privileges = None ,
451
+ temporary = False ,
452
+ rewriteEnabled = None ,
453
+ creationMetadata = None ,
454
+ catName = None ,
455
+ ownerType = 1 ,
456
+ writeId = - 1 ,
457
+ isStatsCompliant = None ,
458
+ colStats = None ,
459
+ accessType = None ,
460
+ requiredReadCapabilities = None ,
461
+ requiredWriteCapabilities = None ,
462
+ id = None ,
463
+ fileMetadata = None ,
464
+ dictionary = None ,
465
+ txnId = None ,
466
+ )
467
+ )
468
+
469
+ with open (metadata_location , encoding = UTF8 ) as f :
470
+ payload = f .read ()
471
+
472
+ metadata = TableMetadataUtil .parse_raw (payload )
473
+
474
+ assert "database/table-given-location" in metadata .location
475
+
476
+ expected = TableMetadataV2 (
477
+ location = metadata .location ,
478
+ table_uuid = metadata .table_uuid ,
479
+ last_updated_ms = metadata .last_updated_ms ,
480
+ last_column_id = 22 ,
481
+ schemas = [
482
+ Schema (
483
+ NestedField (field_id = 1 , name = 'boolean' , field_type = BooleanType (), required = True ),
484
+ NestedField (field_id = 2 , name = 'integer' , field_type = IntegerType (), required = True ),
485
+ NestedField (field_id = 3 , name = 'long' , field_type = LongType (), required = True ),
486
+ NestedField (field_id = 4 , name = 'float' , field_type = FloatType (), required = True ),
487
+ NestedField (field_id = 5 , name = 'double' , field_type = DoubleType (), required = True ),
488
+ NestedField (field_id = 6 , name = 'decimal' , field_type = DecimalType (precision = 32 , scale = 3 ), required = True ),
489
+ NestedField (field_id = 7 , name = 'date' , field_type = DateType (), required = True ),
490
+ NestedField (field_id = 8 , name = 'time' , field_type = TimeType (), required = True ),
491
+ NestedField (field_id = 9 , name = 'timestamp' , field_type = TimestampType (), required = True ),
492
+ NestedField (field_id = 10 , name = 'timestamptz' , field_type = TimestamptzType (), required = True ),
493
+ NestedField (field_id = 11 , name = 'string' , field_type = StringType (), required = True ),
494
+ NestedField (field_id = 12 , name = 'uuid' , field_type = UUIDType (), required = True ),
495
+ NestedField (field_id = 13 , name = 'fixed' , field_type = FixedType (length = 12 ), required = True ),
496
+ NestedField (field_id = 14 , name = 'binary' , field_type = BinaryType (), required = True ),
497
+ NestedField (
498
+ field_id = 15 ,
499
+ name = 'list' ,
500
+ field_type = ListType (type = 'list' , element_id = 18 , element_type = StringType (), element_required = True ),
501
+ required = True ,
502
+ ),
503
+ NestedField (
504
+ field_id = 16 ,
505
+ name = 'map' ,
506
+ field_type = MapType (
507
+ type = 'map' , key_id = 19 , key_type = StringType (), value_id = 20 , value_type = IntegerType (), value_required = True
508
+ ),
509
+ required = True ,
510
+ ),
511
+ NestedField (
512
+ field_id = 17 ,
513
+ name = 'struct' ,
514
+ field_type = StructType (
515
+ NestedField (field_id = 21 , name = 'inner_string' , field_type = StringType (), required = False ),
516
+ NestedField (field_id = 22 , name = 'inner_int' , field_type = IntegerType (), required = True ),
517
+ ),
518
+ required = False ,
519
+ ),
520
+ schema_id = 0 ,
521
+ identifier_field_ids = [2 ],
522
+ )
523
+ ],
524
+ current_schema_id = 0 ,
525
+ last_partition_id = 999 ,
526
+ properties = {"owner" : "javaberg" , 'write.parquet.compression-codec' : 'zstd' },
527
+ partition_specs = [PartitionSpec ()],
528
+ default_spec_id = 0 ,
529
+ current_snapshot_id = None ,
530
+ snapshots = [],
531
+ snapshot_log = [],
532
+ metadata_log = [],
533
+ sort_orders = [SortOrder (order_id = 0 )],
534
+ default_sort_order_id = 0 ,
535
+ refs = {},
536
+ format_version = 2 ,
537
+ last_sequence_number = 0 ,
538
+ )
539
+
540
+ assert metadata .model_dump () == expected .model_dump ()
541
+
542
+
368
543
@patch ("time.time" , MagicMock (return_value = 12345 ))
369
544
def test_create_v1_table (table_schema_simple : Schema , hive_database : HiveDatabase , hive_table : HiveTable ) -> None :
370
545
catalog = HiveCatalog (HIVE_CATALOG_NAME , uri = HIVE_METASTORE_FAKE_URL )
0 commit comments