diff --git a/crates/integrations/datafusion/src/table/mod.rs b/crates/integrations/datafusion/src/table/mod.rs
index 7f741a534a..8cdeb6654e 100644
--- a/crates/integrations/datafusion/src/table/mod.rs
+++ b/crates/integrations/datafusion/src/table/mod.rs
@@ -24,17 +24,23 @@ use std::sync::Arc;
 use async_trait::async_trait;
 use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef;
 use datafusion::catalog::Session;
+use datafusion::common::DataFusionError;
 use datafusion::datasource::{TableProvider, TableType};
 use datafusion::error::Result as DFResult;
+use datafusion::logical_expr::dml::InsertOp;
 use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
 use datafusion::physical_plan::ExecutionPlan;
+use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
 use iceberg::arrow::schema_to_arrow_schema;
 use iceberg::inspect::MetadataTableType;
 use iceberg::table::Table;
 use iceberg::{Catalog, Error, ErrorKind, NamespaceIdent, Result, TableIdent};
 use metadata_table::IcebergMetadataTableProvider;

+use crate::physical_plan::commit::IcebergCommitExec;
 use crate::physical_plan::scan::IcebergTableScan;
+use crate::physical_plan::write::IcebergWriteExec;
+use crate::to_datafusion_error;

 /// Represents a [`TableProvider`] for the Iceberg [`Catalog`],
 /// managing access to a [`Table`].
@@ -46,6 +52,8 @@ pub struct IcebergTableProvider {
     snapshot_id: Option<i64>,
     /// A reference-counted arrow `Schema`.
     schema: ArrowSchemaRef,
+    /// The catalog that the table belongs to.
+    catalog: Option<Arc<dyn Catalog>>,
 }

 impl IcebergTableProvider {
@@ -54,6 +62,7 @@ impl IcebergTableProvider {
             table,
             snapshot_id: None,
             schema,
+            catalog: None,
         }
     }
     /// Asynchronously tries to construct a new [`IcebergTableProvider`]
@@ -73,6 +82,7 @@ impl IcebergTableProvider {
             table,
             snapshot_id: None,
             schema,
+            catalog: Some(client),
         })
     }

@@ -84,6 +94,7 @@ impl IcebergTableProvider {
             table,
             snapshot_id: None,
             schema,
+            catalog: None,
         })
     }

@@ -108,6 +119,7 @@ impl IcebergTableProvider {
             table,
             snapshot_id: Some(snapshot_id),
             schema,
+            catalog: None,
         })
     }

@@ -140,8 +152,18 @@ impl TableProvider for IcebergTableProvider {
         filters: &[Expr],
         _limit: Option<usize>,
     ) -> DFResult<Arc<dyn ExecutionPlan>> {
+        // Refresh table if catalog is available
+        let table = if let Some(catalog) = &self.catalog {
+            catalog
+                .load_table(self.table.identifier())
+                .await
+                .map_err(to_datafusion_error)?
+        } else {
+            self.table.clone()
+        };
+
         Ok(Arc::new(IcebergTableScan::new(
-            self.table.clone(),
+            table,
             self.snapshot_id,
             self.schema.clone(),
             projection,
@@ -152,11 +174,52 @@ impl TableProvider for IcebergTableProvider {
     fn supports_filters_pushdown(
         &self,
         filters: &[&Expr],
-    ) -> std::result::Result<Vec<TableProviderFilterPushDown>, datafusion::error::DataFusionError>
-    {
+    ) -> DFResult<Vec<TableProviderFilterPushDown>> {
         // Push down all filters, as a single source of truth, the scanner will drop the filters which couldn't be pushed down
         Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()])
     }
+
+    async fn insert_into(
+        &self,
+        _state: &dyn Session,
+        input: Arc<dyn ExecutionPlan>,
+        _insert_op: InsertOp,
+    ) -> DFResult<Arc<dyn ExecutionPlan>> {
+        if !self
+            .table
+            .metadata()
+            .default_partition_spec()
+            .is_unpartitioned()
+        {
+            // TODO add insert into support for partitioned tables
+            return Err(DataFusionError::NotImplemented(
+                "IcebergTableProvider::insert_into does not support partitioned tables yet"
+                    .to_string(),
+            ));
+        }
+
+        let Some(catalog) = self.catalog.clone() else {
+            return Err(DataFusionError::Execution(
+                "Catalog cannot be none for insert_into".to_string(),
+            ));
+        };
+
+        let write_plan = Arc::new(IcebergWriteExec::new(
+            self.table.clone(),
+            input,
+            self.schema.clone(),
+        ));
+
+        // Merge the outputs of write_plan into one so we can commit all files together
+        let coalesce_partitions = Arc::new(CoalescePartitionsExec::new(write_plan));
+
+        Ok(Arc::new(IcebergCommitExec::new(
+            self.table.clone(),
+            catalog,
+            coalesce_partitions,
+            self.schema.clone(),
+        )))
+    }
 }

 #[cfg(test)]
diff --git a/crates/integrations/datafusion/tests/integration_datafusion_test.rs b/crates/integrations/datafusion/tests/integration_datafusion_test.rs
index 1491e4dbff..56d43ae04e 100644
--- a/crates/integrations/datafusion/tests/integration_datafusion_test.rs
+++ b/crates/integrations/datafusion/tests/integration_datafusion_test.rs
@@ -21,7 +21,7 @@ use std::collections::HashMap;
 use std::sync::Arc;
 use std::vec;

-use datafusion::arrow::array::{Array, StringArray};
+use datafusion::arrow::array::{Array, StringArray, UInt64Array};
 use datafusion::arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
 use datafusion::execution::context::SessionContext;
 use datafusion::parquet::arrow::PARQUET_FIELD_ID_META_KEY;
@@ -432,3 +432,370 @@ async fn test_metadata_table() -> Result<()> {

     Ok(())
 }
+
+#[tokio::test]
+async fn test_insert_into() -> Result<()> {
+    let iceberg_catalog = get_iceberg_catalog();
+    let namespace = NamespaceIdent::new("test_insert_into".to_string());
+    set_test_namespace(&iceberg_catalog, &namespace).await?;
+
+    let creation = get_table_creation(temp_path(), "my_table", None)?;
+    iceberg_catalog.create_table(&namespace, creation).await?;
+
+    let client = Arc::new(iceberg_catalog);
+    let catalog = Arc::new(IcebergCatalogProvider::try_new(client).await?);
+
+    let ctx = SessionContext::new();
+    ctx.register_catalog("catalog", catalog);
+
+    // Verify table schema
+    let provider = ctx.catalog("catalog").unwrap();
+    let schema = provider.schema("test_insert_into").unwrap();
+    let table = schema.table("my_table").await.unwrap().unwrap();
+    let table_schema = table.schema();
+
+    let expected = [("foo1", &DataType::Int32), ("foo2", &DataType::Utf8)];
+    for (field, exp) in table_schema.fields().iter().zip(expected.iter()) {
+        assert_eq!(field.name(), exp.0);
+        assert_eq!(field.data_type(), exp.1);
+        assert!(!field.is_nullable())
+    }
+
+    // Insert data into the table
+    let df = ctx
+        .sql("INSERT INTO catalog.test_insert_into.my_table VALUES (1, 'alan'), (2, 'turing')")
+        .await
+        .unwrap();
+
+    // Verify the insert operation result
+    let batches = df.collect().await.unwrap();
+    assert_eq!(batches.len(), 1);
+    let batch = &batches[0];
+    assert!(
+        batch.num_rows() == 1 && batch.num_columns() == 1,
+        "Results should only have one row and one column that has the number of rows inserted"
+    );
+    // Verify the number of rows inserted
+    let rows_inserted = batch
+        .column(0)
+        .as_any()
+        .downcast_ref::<UInt64Array>()
+        .unwrap();
+    assert_eq!(rows_inserted.value(0), 2);
+
+    // Query the table to verify the inserted data
+    let df = ctx
+        .sql("SELECT * FROM catalog.test_insert_into.my_table")
+        .await
+        .unwrap();
+
+    let batches = df.collect().await.unwrap();
+
+    // Use check_record_batches to verify the data
+    check_record_batches(
+        batches,
+        expect![[r#"
+            Field { name: "foo1", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {"PARQUET:field_id": "1"} },
+            Field { name: "foo2", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {"PARQUET:field_id": "2"} }"#]],
+        expect![[r#"
+            foo1: PrimitiveArray<Int32>
+            [
+              1,
+              2,
+            ],
+            foo2: StringArray
+            [
+              "alan",
+              "turing",
+            ]"#]],
+        &[],
+        Some("foo1"),
+    );
+
+    Ok(())
+}
+
+fn get_nested_struct_type() -> StructType {
+    // Create a nested struct type with:
+    // - address: STRUCT<street: string, city: string, zip: int>
+    // - contact: STRUCT<email: string, phone: string>
+    StructType::new(vec![
+        NestedField::optional(
+            10,
+            "address",
+            Type::Struct(StructType::new(vec![
+                NestedField::required(11, "street", Type::Primitive(PrimitiveType::String)).into(),
+                NestedField::required(12, "city", Type::Primitive(PrimitiveType::String)).into(),
+                NestedField::required(13, "zip", Type::Primitive(PrimitiveType::Int)).into(),
+            ])),
+        )
+        .into(),
+        NestedField::optional(
+            20,
+            "contact",
+            Type::Struct(StructType::new(vec![
+                NestedField::optional(21, "email", Type::Primitive(PrimitiveType::String)).into(),
+                NestedField::optional(22, "phone", Type::Primitive(PrimitiveType::String)).into(),
+            ])),
+        )
+        .into(),
+    ])
+}
+
+#[tokio::test]
+async fn test_insert_into_nested() -> Result<()> {
+    let iceberg_catalog = get_iceberg_catalog();
+    let namespace = NamespaceIdent::new("test_insert_nested".to_string());
+    set_test_namespace(&iceberg_catalog, &namespace).await?;
+    let table_name = "nested_table";
+
+    // Create a schema with nested fields
+    let schema = Schema::builder()
+        .with_schema_id(0)
+        .with_fields(vec![
+            NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
+            NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
+            NestedField::optional(3, "profile", Type::Struct(get_nested_struct_type())).into(),
+        ])
+        .build()?;
+
+    // Create the table with the nested schema
+    let creation = get_table_creation(temp_path(), table_name, Some(schema))?;
+    iceberg_catalog.create_table(&namespace, creation).await?;
+
+    let client = Arc::new(iceberg_catalog);
+    let catalog = Arc::new(IcebergCatalogProvider::try_new(client).await?);
+
+    let ctx = SessionContext::new();
+    ctx.register_catalog("catalog", catalog);
+
+    // Verify table schema
+    let provider = ctx.catalog("catalog").unwrap();
+    let schema = provider.schema("test_insert_nested").unwrap();
+    let table = schema.table("nested_table").await.unwrap().unwrap();
+    let table_schema = table.schema();
+
+    // Verify the schema has the expected structure
+    assert_eq!(table_schema.fields().len(), 3);
+    assert_eq!(table_schema.field(0).name(), "id");
+    assert_eq!(table_schema.field(1).name(), "name");
+    assert_eq!(table_schema.field(2).name(), "profile");
+    assert!(matches!(
+        table_schema.field(2).data_type(),
+        DataType::Struct(_)
+    ));
+
+    // In DataFusion, we need to use named_struct to create struct values
+    // Insert data with nested structs
+    let insert_sql = r#"
+        INSERT INTO catalog.test_insert_nested.nested_table
+        SELECT
+            1 as id,
+            'Alice' as name,
+            named_struct(
+                'address', named_struct(
+                    'street', '123 Main St',
+                    'city', 'San Francisco',
+                    'zip', 94105
+                ),
+                'contact', named_struct(
+                    'email', 'alice@example.com',
+                    'phone', '555-1234'
+                )
+            ) as profile
+        UNION ALL
+        SELECT
+            2 as id,
+            'Bob' as name,
+            named_struct(
+                'address', named_struct(
+                    'street', '456 Market St',
+                    'city', 'San Jose',
+                    'zip', 95113
+                ),
+                'contact', named_struct(
+                    'email', 'bob@example.com',
+                    'phone', NULL
+                )
+            ) as profile
+    "#;
+
+    // Execute the insert
+    let df = ctx.sql(insert_sql).await.unwrap();
+    let batches = df.collect().await.unwrap();
+
+    // Verify the insert operation result
+    assert_eq!(batches.len(), 1);
+    let batch = &batches[0];
+    assert!(batch.num_rows() == 1 && batch.num_columns() == 1);
+
+    // Verify the number of rows inserted
+    let rows_inserted = batch
+        .column(0)
+        .as_any()
+        .downcast_ref::<UInt64Array>()
+        .unwrap();
+    assert_eq!(rows_inserted.value(0), 2);
+
+    ctx.refresh_catalogs().await.unwrap();
+
+    // Query the table to verify the inserted data
+    let df = ctx
+        .sql("SELECT * FROM catalog.test_insert_nested.nested_table ORDER BY id")
+        .await
+        .unwrap();
+
+    let batches = df.collect().await.unwrap();
+
+    // Use check_record_batches to verify the data
+    check_record_batches(
+        batches,
+        expect![[r#"
+            Field { name: "id", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {"PARQUET:field_id": "1"} },
+            Field { name: "name", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {"PARQUET:field_id": "2"} },
+            Field { name: "profile", data_type: Struct([Field { name: "address", data_type: Struct([Field { name: "street", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {"PARQUET:field_id": "6"} }, Field { name: "city", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {"PARQUET:field_id": "7"} }, Field { name: "zip", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {"PARQUET:field_id": "8"} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {"PARQUET:field_id": "4"} }, Field { name: "contact", data_type: Struct([Field { name: "email", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {"PARQUET:field_id": "9"} }, Field { name: "phone", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {"PARQUET:field_id": "10"} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {"PARQUET:field_id": "5"} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {"PARQUET:field_id": "3"} }"#]],
+        expect![[r#"
+            id: PrimitiveArray<Int32>
+            [
+              1,
+              2,
+            ],
+            name: StringArray
+            [
+              "Alice",
+              "Bob",
+            ],
+            profile: StructArray
+            -- validity:
+            [
+              valid,
+              valid,
+            ]
+            [
+            -- child 0: "address" (Struct([Field { name: "street", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {"PARQUET:field_id": "6"} }, Field { name: "city", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {"PARQUET:field_id": "7"} }, Field { name: "zip", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {"PARQUET:field_id": "8"} }]))
+            StructArray
+            -- validity:
+            [
+              valid,
+              valid,
+            ]
+            [
+            -- child 0: "street" (Utf8)
+            StringArray
+            [
+              "123 Main St",
+              "456 Market St",
+            ]
+            -- child 1: "city" (Utf8)
+            StringArray
+            [
+              "San Francisco",
+              "San Jose",
+            ]
+            -- child 2: "zip" (Int32)
+            PrimitiveArray<Int32>
+            [
+              94105,
+              95113,
+            ]
+            ]
+            -- child 1: "contact" (Struct([Field { name: "email", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {"PARQUET:field_id": "9"} }, Field { name: "phone", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {"PARQUET:field_id": "10"} }]))
+            StructArray
+            -- validity:
+            [
+              valid,
+              valid,
+            ]
+            [
+            -- child 0: "email" (Utf8)
+            StringArray
+            [
+              "alice@example.com",
+              "bob@example.com",
+            ]
+            -- child 1: "phone" (Utf8)
+            StringArray
+            [
+              "555-1234",
+              null,
+            ]
+            ]
+            ]"#]],
+        &[],
+        Some("id"),
+    );
+
+    // Query with explicit field access to verify nested data
+    let df = ctx
+        .sql(
+            r#"
+            SELECT
+                id,
+                name,
+                profile.address.street,
+                profile.address.city,
+                profile.address.zip,
+                profile.contact.email,
+                profile.contact.phone
+            FROM catalog.test_insert_nested.nested_table
+            ORDER BY id
+            "#,
+        )
+        .await
+        .unwrap();
+
+    let batches = df.collect().await.unwrap();
+
+    // Use check_record_batches to verify the flattened data
+    check_record_batches(
+        batches,
+        expect![[r#"
+            Field { name: "id", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {"PARQUET:field_id": "1"} },
+            Field { name: "name", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {"PARQUET:field_id": "2"} },
+            Field { name: "catalog.test_insert_nested.nested_table.profile[address][street]", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} },
+            Field { name: "catalog.test_insert_nested.nested_table.profile[address][city]", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} },
+            Field { name: "catalog.test_insert_nested.nested_table.profile[address][zip]", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} },
+            Field { name: "catalog.test_insert_nested.nested_table.profile[contact][email]", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} },
+            Field { name: "catalog.test_insert_nested.nested_table.profile[contact][phone]", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }"#]],
+        expect![[r#"
+            id: PrimitiveArray<Int32>
+            [
+              1,
+              2,
+            ],
+            name: StringArray
+            [
+              "Alice",
+              "Bob",
+            ],
+            catalog.test_insert_nested.nested_table.profile[address][street]: StringArray
+            [
+              "123 Main St",
+              "456 Market St",
+            ],
+            catalog.test_insert_nested.nested_table.profile[address][city]: StringArray
+            [
+              "San Francisco",
+              "San Jose",
+            ],
+            catalog.test_insert_nested.nested_table.profile[address][zip]: PrimitiveArray<Int32>
+            [
+              94105,
+              95113,
+            ],
+            catalog.test_insert_nested.nested_table.profile[contact][email]: StringArray
+            [
+              "alice@example.com",
+              "bob@example.com",
+            ],
+            catalog.test_insert_nested.nested_table.profile[contact][phone]: StringArray
+            [
+              "555-1234",
+              null,
+            ]"#]],
+        &[],
+        Some("id"),
+    );
+
+    Ok(())
+}
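
For context on how the new write path is driven outside the catalog-provider setup used by the integration tests above, the snippet below is a minimal sketch (not part of the patch): it builds a provider with IcebergTableProvider::try_new, the only constructor that keeps a catalog handle and therefore the only one whose insert_into can commit, and runs a SQL INSERT through DataFusion. It assumes the existing try_new(client, namespace, table_name) constructor, an already-constructed Arc<dyn Catalog>, and the crate-root re-export of IcebergTableProvider.

use std::sync::Arc;

use datafusion::execution::context::SessionContext;
use iceberg::{Catalog, NamespaceIdent};
use iceberg_datafusion::IcebergTableProvider;

async fn insert_rows(client: Arc<dyn Catalog>) -> Result<(), Box<dyn std::error::Error>> {
    // try_new loads the table and stores the catalog handle, so insert_into can
    // refresh the table and commit the appended data files.
    let provider = IcebergTableProvider::try_new(
        client,
        NamespaceIdent::new("test_insert_into".to_string()),
        "my_table".to_string(),
    )
    .await?;

    let ctx = SessionContext::new();
    ctx.register_table("my_table", Arc::new(provider))?;

    // Only unpartitioned tables are supported for now; partitioned tables
    // return DataFusionError::NotImplemented from insert_into.
    ctx.sql("INSERT INTO my_table VALUES (1, 'alan'), (2, 'turing')")
        .await?
        .collect()
        .await?;

    Ok(())
}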