|
1 | 1 | use datafusion::{ |
| 2 | + arrow::{ |
| 3 | + array::{Array, ArrayRef, DictionaryArray, StringArray, StringViewArray}, |
| 4 | + datatypes::{DataType, Field, Schema, UInt16Type}, |
| 5 | + record_batch::RecordBatch, |
| 6 | + }, |
2 | 7 | common::{internal_datafusion_err, internal_err}, |
3 | 8 | error::Result, |
4 | 9 | execution::context::SessionContext, |
5 | 10 | prelude::ParquetReadOptions, |
6 | 11 | }; |
| 12 | +use parquet::{arrow::ArrowWriter, file::properties::WriterProperties}; |
7 | 13 | use std::fs; |
8 | 14 | use std::path::Path; |
9 | 15 | use std::process::Command; |
| 16 | +use std::sync::Arc; |
10 | 17 |
|
11 | 18 | pub fn get_data_dir() -> std::path::PathBuf { |
12 | 19 | std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("testdata/tpcds/data") |
@@ -74,15 +81,147 @@ pub const TPCDS_TABLES: &[&str] = &[ |
74 | 81 | "web_site", |
75 | 82 | ]; |
76 | 83 |
|
| 84 | +/// Tables that should have dictionary encoding applied for testing |
| 85 | +const DICT_ENCODING_TABLES: &[&str] = &["item", "customer", "store"]; |
| 86 | + |
| 87 | +/// Force dictionary encoding for specific string columns in a table for extra test coverage. |
| 88 | +fn force_dictionary_encoding_for_table( |
| 89 | + table_name: &str, |
| 90 | + batch: RecordBatch, |
| 91 | +) -> Result<RecordBatch> { |
| 92 | + let dict_columns = match table_name { |
| 93 | + "item" => vec!["i_brand", "i_category", "i_class", "i_color", "i_size"], |
| 94 | + "customer" => vec!["c_salutation"], |
| 95 | + "store" => vec!["s_state", "s_country"], |
| 96 | + _ => vec![], // No dictionary encoding for other tables |
| 97 | + }; |
| 98 | + |
| 99 | + if dict_columns.is_empty() { |
| 100 | + return Ok(batch); |
| 101 | + } |
| 102 | + |
| 103 | + let schema = batch.schema(); |
| 104 | + let mut new_fields = Vec::new(); |
| 105 | + let mut new_columns = Vec::new(); |
| 106 | + |
| 107 | + for (i, field) in schema.fields().iter().enumerate() { |
| 108 | + let column = batch.column(i); |
| 109 | + |
| 110 | + // Check if this column should be dictionary-encoded |
| 111 | + if dict_columns.contains(&field.name().as_str()) |
| 112 | + && matches!(field.data_type(), DataType::Utf8 | DataType::Utf8View) |
| 113 | + { |
| 114 | + // Convert to dictionary encoding |
| 115 | + let string_data = |
| 116 | + if let Some(string_array) = column.as_any().downcast_ref::<StringArray>() { |
| 117 | + string_array.iter().collect::<Vec<_>>() |
| 118 | + } else if let Some(view_array) = column.as_any().downcast_ref::<StringViewArray>() { |
| 119 | + view_array.iter().collect::<Vec<_>>() |
| 120 | + } else { |
| 121 | + return internal_err!("Expected string array for column {}", field.name()); |
| 122 | + }; |
| 123 | + |
| 124 | + let dict_array: DictionaryArray<UInt16Type> = string_data.into_iter().collect(); |
| 125 | + let dict_field = Field::new( |
| 126 | + field.name(), |
| 127 | + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), |
| 128 | + field.is_nullable(), |
| 129 | + ); |
| 130 | + |
| 131 | + new_fields.push(dict_field); |
| 132 | + new_columns.push(Arc::new(dict_array) as ArrayRef); |
| 133 | + } else { |
| 134 | + new_fields.push((**field).clone()); |
| 135 | + new_columns.push(column.clone()); |
| 136 | + } |
| 137 | + } |
| 138 | + |
| 139 | + let new_schema = Arc::new(Schema::new(new_fields)); |
| 140 | + RecordBatch::try_new(new_schema, new_columns).map_err(|e| internal_datafusion_err!("{}", e)) |
| 141 | +} |
| 142 | + |
/// Register a single TPC-DS table on `ctx`, reading parquet data from
/// `data_dir` (or the default data directory when `None`).
///
/// Thin wrapper over [`register_tpcds_table_with_options`] with dictionary
/// encoding disabled, preserving the original public API.
pub async fn register_tpcds_table(
    ctx: &SessionContext,
    table_name: &str,
    data_dir: Option<&Path>,
) -> Result<()> {
    register_tpcds_table_with_options(ctx, table_name, data_dir, false).await
}
| 150 | + |
| 151 | +pub async fn register_tpcds_table_with_options( |
| 152 | + ctx: &SessionContext, |
| 153 | + table_name: &str, |
| 154 | + data_dir: Option<&Path>, |
| 155 | + dict_encode_items_table: bool, |
81 | 156 | ) -> Result<()> { |
82 | 157 | let default_data_dir = get_data_dir(); |
83 | 158 | let data_path = data_dir.unwrap_or(&default_data_dir); |
84 | 159 |
|
85 | | - // Check if this is a single parquet file |
| 160 | + // Apply dictionary encoding if requested and materialize to disk |
| 161 | + if dict_encode_items_table && DICT_ENCODING_TABLES.contains(&table_name) { |
| 162 | + let table_dir_path = data_path.join(table_name); |
| 163 | + if table_dir_path.is_dir() { |
| 164 | + let dict_table_path = data_path.join(format!("{table_name}_dict")); |
| 165 | + |
| 166 | + // Check if dictionary encoded version already exists |
| 167 | + if dict_table_path.exists() { |
| 168 | + // Use the existing dictionary encoded version |
| 169 | + ctx.register_parquet( |
| 170 | + table_name, |
| 171 | + &dict_table_path.to_string_lossy(), |
| 172 | + ParquetReadOptions::default(), |
| 173 | + ) |
| 174 | + .await?; |
| 175 | + return Ok(()); |
| 176 | + } |
| 177 | + |
| 178 | + // Register temporarily to read the original data |
| 179 | + let temp_table_name = format!("temp_{table_name}"); |
| 180 | + ctx.register_parquet( |
| 181 | + &temp_table_name, |
| 182 | + &table_dir_path.to_string_lossy(), |
| 183 | + ParquetReadOptions::default(), |
| 184 | + ) |
| 185 | + .await?; |
| 186 | + |
| 187 | + // Read data and apply dictionary encoding |
| 188 | + let df = ctx.table(&temp_table_name).await?; |
| 189 | + let batches = df.collect().await?; |
| 190 | + |
| 191 | + let mut dict_batches = Vec::new(); |
| 192 | + for batch in batches { |
| 193 | + dict_batches.push(force_dictionary_encoding_for_table(table_name, batch)?); |
| 194 | + } |
| 195 | + |
| 196 | + // Write dictionary-encoded data to disk |
| 197 | + if !dict_batches.is_empty() { |
| 198 | + fs::create_dir_all(&dict_table_path)?; |
| 199 | + let dict_file_path = dict_table_path.join("data.parquet"); |
| 200 | + let file = fs::File::create(&dict_file_path)?; |
| 201 | + let props = WriterProperties::builder().build(); |
| 202 | + let mut writer = ArrowWriter::try_new(file, dict_batches[0].schema(), Some(props))?; |
| 203 | + |
| 204 | + for batch in &dict_batches { |
| 205 | + writer.write(batch)?; |
| 206 | + } |
| 207 | + writer.close()?; |
| 208 | + |
| 209 | + // Register the dictionary encoded table |
| 210 | + ctx.register_parquet( |
| 211 | + table_name, |
| 212 | + &dict_table_path.to_string_lossy(), |
| 213 | + ParquetReadOptions::default(), |
| 214 | + ) |
| 215 | + .await?; |
| 216 | + } |
| 217 | + |
| 218 | + // Deregister the temporary table |
| 219 | + ctx.deregister_table(&temp_table_name)?; |
| 220 | + return Ok(()); |
| 221 | + } |
| 222 | + } |
| 223 | + |
| 224 | + // Use normal parquet registration for all tables |
86 | 225 | let table_file_path = data_path.join(format!("{table_name}.parquet")); |
87 | 226 | if table_file_path.is_file() { |
88 | 227 | ctx.register_parquet( |
@@ -113,10 +252,17 @@ pub async fn register_tpcds_table( |
113 | 252 | } |
114 | 253 |
|
/// Register every TPC-DS table on `ctx` with default options (no forced
/// dictionary encoding). Delegates to [`register_tables_with_options`].
pub async fn register_tables(ctx: &SessionContext) -> Result<Vec<String>> {
    register_tables_with_options(ctx, false).await
}
| 257 | + |
| 258 | +pub async fn register_tables_with_options( |
| 259 | + ctx: &SessionContext, |
| 260 | + dict_encode_items_table: bool, |
| 261 | +) -> Result<Vec<String>> { |
116 | 262 | let mut registered_tables = Vec::new(); |
117 | 263 |
|
118 | 264 | for &table_name in TPCDS_TABLES { |
119 | | - register_tpcds_table(ctx, table_name, None).await?; |
| 265 | + register_tpcds_table_with_options(ctx, table_name, None, dict_encode_items_table).await?; |
120 | 266 | registered_tables.push(table_name.to_string()); |
121 | 267 | } |
122 | 268 |
|
|
0 commit comments