|
15 | 15 | use std::sync::Arc; |
16 | 16 |
|
17 | 17 | use databend_common_exception::Result; |
| 18 | +use databend_common_expression::BlockEntry; |
| 19 | +use databend_common_expression::Column; |
18 | 20 | use databend_common_meta_app::principal::BUILTIN_ROLE_ACCOUNT_ADMIN; |
19 | 21 | use databend_common_version::BUILD_INFO; |
20 | 22 | use databend_query::sessions::BuildInfoRef; |
@@ -171,41 +173,106 @@ impl PySessionContext { |
171 | 173 | connection: Option<&str>, |
172 | 174 | py: Python, |
173 | 175 | ) -> PyResult<()> { |
174 | | - let sql = if let Some(connection_name) = connection { |
175 | | - let pattern_clause = pattern |
176 | | - .map(|p| format!(", pattern => '{}'", p)) |
177 | | - .unwrap_or_default(); |
178 | | - format!( |
179 | | - "create view {} as select * from '{}' (file_format => '{}'{}, connection => '{}')", |
180 | | - name, path, file_format, pattern_clause, connection_name |
| 176 | + // Resolve file path |
| 177 | + let (file_path, connection_clause) = if let Some(connection_name) = connection { |
| 178 | + ( |
| 179 | + path.to_owned(), |
| 180 | + format!(", connection => '{}'", connection_name), |
181 | 181 | ) |
182 | 182 | } else { |
183 | | - let mut path = path.to_owned(); |
184 | | - if path.starts_with('/') { |
185 | | - path = format!("fs://{}", path); |
| 183 | + let mut p = path.to_owned(); |
| 184 | + if p.starts_with('/') { |
| 185 | + p = format!("fs://{}", p); |
186 | 186 | } |
187 | | - |
188 | | - if !path.contains("://") { |
189 | | - path = format!( |
| 187 | + if !p.contains("://") { |
| 188 | + p = format!( |
190 | 189 | "fs://{}/{}", |
191 | 190 | std::env::current_dir().unwrap().to_str().unwrap(), |
192 | | - path.as_str() |
| 191 | + p.as_str() |
193 | 192 | ); |
194 | 193 | } |
| 194 | + (p, String::new()) |
| 195 | + }; |
195 | 196 |
|
196 | | - let pattern_clause = pattern |
197 | | - .map(|p| format!(", pattern => '{}'", p)) |
198 | | - .unwrap_or_default(); |
199 | | - format!( |
200 | | - "create view {} as select * from '{}' (file_format => '{}'{})", |
201 | | - name, path, file_format, pattern_clause |
202 | | - ) |
| 197 | + let pattern_clause = pattern |
| 198 | + .map(|p| format!(", pattern => '{}'", p)) |
| 199 | + .unwrap_or_default(); |
| 200 | + |
| 201 | + // For CSV/TSV, use infer_schema to get column positions instead of SELECT * |
| 202 | + let select_clause = if file_format == "csv" || file_format == "tsv" { |
| 203 | + let col_names = |
| 204 | + self.infer_column_names(&file_path, file_format, connection, py)?; |
| 205 | + if col_names.is_empty() { |
| 206 | + return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>( |
| 207 | + "Could not infer schema from CSV/TSV file: no columns found", |
| 208 | + )); |
| 209 | + } |
| 210 | + col_names |
| 211 | + .iter() |
| 212 | + .enumerate() |
| 213 | + .map(|(i, col_name)| format!("${} AS `{}`", i + 1, col_name)) |
| 214 | + .collect::<Vec<_>>() |
| 215 | + .join(", ") |
| 216 | + } else { |
| 217 | + "*".to_string() |
203 | 218 | }; |
204 | 219 |
|
| 220 | + let sql = format!( |
| 221 | + "create view {} as select {} from '{}' (file_format => '{}'{}{})", |
| 222 | + name, select_clause, file_path, file_format, pattern_clause, connection_clause |
| 223 | + ); |
| 224 | + |
205 | 225 | let _ = self.sql(&sql, py)?.collect(py)?; |
206 | 226 | Ok(()) |
207 | 227 | } |
208 | 228 |
|
| 229 | + fn infer_column_names( |
| 230 | + &mut self, |
| 231 | + file_path: &str, |
| 232 | + file_format: &str, |
| 233 | + connection: Option<&str>, |
| 234 | + py: Python, |
| 235 | + ) -> PyResult<Vec<String>> { |
| 236 | + let connection_clause = connection |
| 237 | + .map(|c| format!(", connection_name => '{}'", c)) |
| 238 | + .unwrap_or_default(); |
| 239 | + |
| 240 | + let infer_sql = format!( |
| 241 | + "SELECT column_name FROM infer_schema(location => '{}', file_format => '{}'{})", |
| 242 | + file_path, |
| 243 | + file_format.to_uppercase(), |
| 244 | + connection_clause |
| 245 | + ); |
| 246 | + |
| 247 | + let df = self.sql(&infer_sql, py)?; |
| 248 | + let blocks = df.collect(py)?; |
| 249 | + |
| 250 | + let mut col_names = Vec::new(); |
| 251 | + for block in &blocks.blocks { |
| 252 | + if block.num_rows() == 0 { |
| 253 | + continue; |
| 254 | + } |
| 255 | + let entry = block.get_by_offset(0); |
| 256 | + match entry { |
| 257 | + BlockEntry::Column(Column::String(col)) => { |
| 258 | + for val in col.iter() { |
| 259 | + col_names.push(val.to_string()); |
| 260 | + } |
| 261 | + } |
| 262 | + BlockEntry::Column(Column::Nullable(nullable_col)) => { |
| 263 | + if let Column::String(col) = &nullable_col.column { |
| 264 | + for val in col.iter() { |
| 265 | + col_names.push(val.to_string()); |
| 266 | + } |
| 267 | + } |
| 268 | + } |
| 269 | + _ => {} |
| 270 | + } |
| 271 | + } |
| 272 | + |
| 273 | + Ok(col_names) |
| 274 | + } |
| 275 | + |
209 | 276 | #[pyo3(signature = (name, access_key_id, secret_access_key, endpoint_url = None, region = None))] |
210 | 277 | fn create_s3_connection( |
211 | 278 | &mut self, |
|
0 commit comments