Skip to content

Commit 0de3248

Browse files
committed
refactor: simplify register_csv/tsv implementation
- Extract resolve_file_path() and extract_string_column() as standalone helpers - Replace imperative loop with functional iterator chain - Rename infer_column_names to build_column_select for clarity - Deduplicate mock logic in test_connections.py via _register_delimited()
1 parent ed97583 commit 0de3248

File tree

2 files changed

+75
-93
lines changed

2 files changed

+75
-93
lines changed

src/bendpy/src/context.rs

Lines changed: 59 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,31 @@ use crate::dataframe::default_box_size;
3030
use crate::utils::RUNTIME;
3131
use crate::utils::wait_for_future;
3232

/// Normalize a user-supplied path into a URI the query engine accepts.
///
/// A path that already carries a scheme (`s3://…`, `fs://…`) passes through
/// unchanged; an absolute path gains an `fs://` scheme; a relative path is
/// anchored at the process's current working directory.
fn resolve_file_path(path: &str) -> String {
    if path.contains("://") {
        path.to_owned()
    } else if path.starts_with('/') {
        format!("fs://{}", path)
    } else {
        // Relative path: anchor at cwd so the engine sees an absolute URI.
        let cwd = std::env::current_dir().unwrap();
        format!("fs://{}/{}", cwd.to_str().unwrap(), path)
    }
}
46+
47+
fn extract_string_column(entry: &BlockEntry) -> Option<&databend_common_expression::types::StringColumn> {
48+
match entry {
49+
BlockEntry::Column(Column::String(col)) => Some(col),
50+
BlockEntry::Column(Column::Nullable(n)) => match &n.column {
51+
Column::String(col) => Some(col),
52+
_ => None,
53+
},
54+
_ => None,
55+
}
56+
}
57+
3358
#[pyclass(name = "SessionContext", module = "databend", subclass)]
3459
#[derive(Clone)]
3560
pub(crate) struct PySessionContext {
@@ -173,104 +198,69 @@ impl PySessionContext {
173198
connection: Option<&str>,
174199
py: Python,
175200
) -> PyResult<()> {
176-
// Resolve file path
177-
let (file_path, connection_clause) = if let Some(connection_name) = connection {
178-
(
179-
path.to_owned(),
180-
format!(", connection => '{}'", connection_name),
181-
)
182-
} else {
183-
let mut p = path.to_owned();
184-
if p.starts_with('/') {
185-
p = format!("fs://{}", p);
186-
}
187-
if !p.contains("://") {
188-
p = format!(
189-
"fs://{}/{}",
190-
std::env::current_dir().unwrap().to_str().unwrap(),
191-
p.as_str()
192-
);
193-
}
194-
(p, String::new())
201+
let file_path = match connection {
202+
Some(_) => path.to_owned(),
203+
None => resolve_file_path(path),
195204
};
196-
205+
let connection_clause = connection
206+
.map(|c| format!(", connection => '{}'", c))
207+
.unwrap_or_default();
197208
let pattern_clause = pattern
198209
.map(|p| format!(", pattern => '{}'", p))
199210
.unwrap_or_default();
200211

201-
// For CSV/TSV, use infer_schema to get column positions instead of SELECT *
202-
let select_clause = if file_format == "csv" || file_format == "tsv" {
203-
let col_names =
204-
self.infer_column_names(&file_path, file_format, connection, py)?;
205-
if col_names.is_empty() {
206-
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
207-
"Could not infer schema from CSV/TSV file: no columns found",
208-
));
209-
}
210-
col_names
211-
.iter()
212-
.enumerate()
213-
.map(|(i, col_name)| format!("${} AS `{}`", i + 1, col_name))
214-
.collect::<Vec<_>>()
215-
.join(", ")
216-
} else {
217-
"*".to_string()
212+
let select_clause = match file_format {
213+
"csv" | "tsv" => self.build_column_select(&file_path, file_format, connection, py)?,
214+
_ => "*".to_string(),
218215
};
219216

220217
let sql = format!(
221218
"create view {} as select {} from '{}' (file_format => '{}'{}{})",
222219
name, select_clause, file_path, file_format, pattern_clause, connection_clause
223220
);
224-
225221
let _ = self.sql(&sql, py)?.collect(py)?;
226222
Ok(())
227223
}
228224

229-
fn infer_column_names(
225+
/// Infer column names via `infer_schema` and build `$1 AS col1, $2 AS col2, ...`.
226+
fn build_column_select(
230227
&mut self,
231228
file_path: &str,
232229
file_format: &str,
233230
connection: Option<&str>,
234231
py: Python,
235-
) -> PyResult<Vec<String>> {
236-
let connection_clause = connection
232+
) -> PyResult<String> {
233+
let conn_clause = connection
237234
.map(|c| format!(", connection_name => '{}'", c))
238235
.unwrap_or_default();
239-
240-
let infer_sql = format!(
236+
let sql = format!(
241237
"SELECT column_name FROM infer_schema(location => '{}', file_format => '{}'{})",
242238
file_path,
243239
file_format.to_uppercase(),
244-
connection_clause
240+
conn_clause
245241
);
246242

247-
let df = self.sql(&infer_sql, py)?;
248-
let blocks = df.collect(py)?;
249-
250-
let mut col_names = Vec::new();
251-
for block in &blocks.blocks {
252-
if block.num_rows() == 0 {
253-
continue;
254-
}
255-
let entry = block.get_by_offset(0);
256-
match entry {
257-
BlockEntry::Column(Column::String(col)) => {
258-
for val in col.iter() {
259-
col_names.push(val.to_string());
260-
}
261-
}
262-
BlockEntry::Column(Column::Nullable(nullable_col)) => {
263-
if let Column::String(col) = &nullable_col.column {
264-
for val in col.iter() {
265-
col_names.push(val.to_string());
266-
}
267-
}
268-
}
269-
_ => {}
270-
}
243+
let blocks = self.sql(&sql, py)?.collect(py)?;
244+
let col_names: Vec<String> = blocks
245+
.blocks
246+
.iter()
247+
.filter(|b| b.num_rows() > 0)
248+
.filter_map(|b| extract_string_column(b.get_by_offset(0)))
249+
.flat_map(|col| col.iter().map(|s| s.to_string()))
250+
.collect();
251+
252+
if col_names.is_empty() {
253+
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
254+
"Could not infer schema: no columns found",
255+
));
271256
}
272257

273-
Ok(col_names)
258+
Ok(col_names
259+
.iter()
260+
.enumerate()
261+
.map(|(i, name)| format!("${} AS `{}`", i + 1, name))
262+
.collect::<Vec<_>>()
263+
.join(", "))
274264
}
275265

276266
#[pyo3(signature = (name, access_key_id, secret_access_key, endpoint_url = None, region = None))]

src/bendpy/tests/test_connections.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -44,30 +44,22 @@ def register_parquet(self, name, path, pattern=None, connection=None):
4444
self.sql(sql)
4545

4646
def register_csv(self, name, path, pattern=None, connection=None):
    """Register a CSV file as a view named `name`.

    Thin wrapper that delegates to the shared delimited-format helper.
    """
    return self._register_delimited(name, path, "csv", pattern, connection)
49+
def _register_delimited(self, name, path, fmt, pattern=None, connection=None):
    """CSV/TSV: infer schema first, then create view with column positions."""
    if connection:
        file_path = path
        infer_conn = f", connection_name => '{connection}'"
        view_conn = f", connection => '{connection}'"
    else:
        # Local absolute paths need the fs:// scheme.
        file_path = f"fs://{path}" if path.startswith("/") else path
        infer_conn = ""
        view_conn = ""
    self.sql(
        f"SELECT column_name FROM infer_schema(location => '{file_path}', "
        f"file_format => '{fmt.upper()}'{infer_conn})"
    )
    # Simulated: infer_schema reports three columns.
    select = ", ".join(f"${i} AS `col{i}`" for i in range(1, 4))
    pattern_clause = f", pattern => '{pattern}'" if pattern else ""
    self.sql(
        f"create view {name} as select {select} from '{file_path}' "
        f"(file_format => '{fmt}'{pattern_clause}{view_conn})"
    )
7163

7264
def create_azblob_connection(self, name, endpoint_url, account_name, account_key):
7365
sql = f"CREATE OR REPLACE CONNECTION {name} STORAGE_TYPE = 'AZBLOB' endpoint_url = '{endpoint_url}' account_name = '{account_name}' account_key = '{account_key}'"

0 commit comments

Comments
 (0)