@@ -30,6 +30,31 @@ use crate::dataframe::default_box_size;
3030use crate :: utils:: RUNTIME ;
3131use crate :: utils:: wait_for_future;
3232
/// Normalize `path` into a location URI understood by the engine.
///
/// - Paths already carrying a scheme (`s3://`, `fs://`, ...) pass through unchanged.
/// - Absolute paths get the `fs://` scheme prepended.
/// - Relative paths are resolved against the current working directory.
fn resolve_file_path(path: &str) -> String {
    if path.contains("://") {
        return path.to_owned();
    }
    if path.starts_with('/') {
        return format!("fs://{}", path);
    }
    // Resolve relative paths against the CWD. `expect` documents the invariant
    // instead of a bare `unwrap`, and `to_string_lossy` avoids the panic the
    // previous `to_str().unwrap()` caused on a non-UTF-8 working directory.
    let cwd = std::env::current_dir()
        .expect("current working directory must exist and be accessible");
    format!("fs://{}/{}", cwd.to_string_lossy(), path)
}
46+
47+ fn extract_string_column ( entry : & BlockEntry ) -> Option < & databend_common_expression:: types:: StringColumn > {
48+ match entry {
49+ BlockEntry :: Column ( Column :: String ( col) ) => Some ( col) ,
50+ BlockEntry :: Column ( Column :: Nullable ( n) ) => match & n. column {
51+ Column :: String ( col) => Some ( col) ,
52+ _ => None ,
53+ } ,
54+ _ => None ,
55+ }
56+ }
57+
3358#[ pyclass( name = "SessionContext" , module = "databend" , subclass) ]
3459#[ derive( Clone ) ]
3560pub ( crate ) struct PySessionContext {
@@ -173,104 +198,69 @@ impl PySessionContext {
173198 connection : Option < & str > ,
174199 py : Python ,
175200 ) -> PyResult < ( ) > {
176- // Resolve file path
177- let ( file_path, connection_clause) = if let Some ( connection_name) = connection {
178- (
179- path. to_owned ( ) ,
180- format ! ( ", connection => '{}'" , connection_name) ,
181- )
182- } else {
183- let mut p = path. to_owned ( ) ;
184- if p. starts_with ( '/' ) {
185- p = format ! ( "fs://{}" , p) ;
186- }
187- if !p. contains ( "://" ) {
188- p = format ! (
189- "fs://{}/{}" ,
190- std:: env:: current_dir( ) . unwrap( ) . to_str( ) . unwrap( ) ,
191- p. as_str( )
192- ) ;
193- }
194- ( p, String :: new ( ) )
201+ let file_path = match connection {
202+ Some ( _) => path. to_owned ( ) ,
203+ None => resolve_file_path ( path) ,
195204 } ;
196-
205+ let connection_clause = connection
206+ . map ( |c| format ! ( ", connection => '{}'" , c) )
207+ . unwrap_or_default ( ) ;
197208 let pattern_clause = pattern
198209 . map ( |p| format ! ( ", pattern => '{}'" , p) )
199210 . unwrap_or_default ( ) ;
200211
201- // For CSV/TSV, use infer_schema to get column positions instead of SELECT *
202- let select_clause = if file_format == "csv" || file_format == "tsv" {
203- let col_names =
204- self . infer_column_names ( & file_path, file_format, connection, py) ?;
205- if col_names. is_empty ( ) {
206- return Err ( PyErr :: new :: < pyo3:: exceptions:: PyRuntimeError , _ > (
207- "Could not infer schema from CSV/TSV file: no columns found" ,
208- ) ) ;
209- }
210- col_names
211- . iter ( )
212- . enumerate ( )
213- . map ( |( i, col_name) | format ! ( "${} AS `{}`" , i + 1 , col_name) )
214- . collect :: < Vec < _ > > ( )
215- . join ( ", " )
216- } else {
217- "*" . to_string ( )
212+ let select_clause = match file_format {
213+ "csv" | "tsv" => self . build_column_select ( & file_path, file_format, connection, py) ?,
214+ _ => "*" . to_string ( ) ,
218215 } ;
219216
220217 let sql = format ! (
221218 "create view {} as select {} from '{}' (file_format => '{}'{}{})" ,
222219 name, select_clause, file_path, file_format, pattern_clause, connection_clause
223220 ) ;
224-
225221 let _ = self . sql ( & sql, py) ?. collect ( py) ?;
226222 Ok ( ( ) )
227223 }
228224
229- fn infer_column_names (
225+ /// Infer column names via `infer_schema` and build `$1 AS col1, $2 AS col2, ...`.
226+ fn build_column_select (
230227 & mut self ,
231228 file_path : & str ,
232229 file_format : & str ,
233230 connection : Option < & str > ,
234231 py : Python ,
235- ) -> PyResult < Vec < String > > {
236- let connection_clause = connection
232+ ) -> PyResult < String > {
233+ let conn_clause = connection
237234 . map ( |c| format ! ( ", connection_name => '{}'" , c) )
238235 . unwrap_or_default ( ) ;
239-
240- let infer_sql = format ! (
236+ let sql = format ! (
241237 "SELECT column_name FROM infer_schema(location => '{}', file_format => '{}'{})" ,
242238 file_path,
243239 file_format. to_uppercase( ) ,
244- connection_clause
240+ conn_clause
245241 ) ;
246242
247- let df = self . sql ( & infer_sql, py) ?;
248- let blocks = df. collect ( py) ?;
249-
250- let mut col_names = Vec :: new ( ) ;
251- for block in & blocks. blocks {
252- if block. num_rows ( ) == 0 {
253- continue ;
254- }
255- let entry = block. get_by_offset ( 0 ) ;
256- match entry {
257- BlockEntry :: Column ( Column :: String ( col) ) => {
258- for val in col. iter ( ) {
259- col_names. push ( val. to_string ( ) ) ;
260- }
261- }
262- BlockEntry :: Column ( Column :: Nullable ( nullable_col) ) => {
263- if let Column :: String ( col) = & nullable_col. column {
264- for val in col. iter ( ) {
265- col_names. push ( val. to_string ( ) ) ;
266- }
267- }
268- }
269- _ => { }
270- }
243+ let blocks = self . sql ( & sql, py) ?. collect ( py) ?;
244+ let col_names: Vec < String > = blocks
245+ . blocks
246+ . iter ( )
247+ . filter ( |b| b. num_rows ( ) > 0 )
248+ . filter_map ( |b| extract_string_column ( b. get_by_offset ( 0 ) ) )
249+ . flat_map ( |col| col. iter ( ) . map ( |s| s. to_string ( ) ) )
250+ . collect ( ) ;
251+
252+ if col_names. is_empty ( ) {
253+ return Err ( PyErr :: new :: < pyo3:: exceptions:: PyRuntimeError , _ > (
254+ "Could not infer schema: no columns found" ,
255+ ) ) ;
271256 }
272257
273- Ok ( col_names)
258+ Ok ( col_names
259+ . iter ( )
260+ . enumerate ( )
261+ . map ( |( i, name) | format ! ( "${} AS `{}`" , i + 1 , name) )
262+ . collect :: < Vec < _ > > ( )
263+ . join ( ", " ) )
274264 }
275265
276266 #[ pyo3( signature = ( name, access_key_id, secret_access_key, endpoint_url = None , region = None ) ) ]
0 commit comments