Skip to content

Commit ed97583

Browse files
committed
fix(bendpy): use infer_schema for CSV/TSV column positions in register_csv/register_tsv
Fixes #19443
1 parent 4350e54 commit ed97583

File tree

3 files changed

+156
-27
lines changed

3 files changed

+156
-27
lines changed

src/bendpy/src/context.rs

Lines changed: 88 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
use std::sync::Arc;
1616

1717
use databend_common_exception::Result;
18+
use databend_common_expression::BlockEntry;
19+
use databend_common_expression::Column;
1820
use databend_common_meta_app::principal::BUILTIN_ROLE_ACCOUNT_ADMIN;
1921
use databend_common_version::BUILD_INFO;
2022
use databend_query::sessions::BuildInfoRef;
@@ -171,41 +173,106 @@ impl PySessionContext {
171173
connection: Option<&str>,
172174
py: Python,
173175
) -> PyResult<()> {
174-
let sql = if let Some(connection_name) = connection {
175-
let pattern_clause = pattern
176-
.map(|p| format!(", pattern => '{}'", p))
177-
.unwrap_or_default();
178-
format!(
179-
"create view {} as select * from '{}' (file_format => '{}'{}, connection => '{}')",
180-
name, path, file_format, pattern_clause, connection_name
176+
// Resolve file path
177+
let (file_path, connection_clause) = if let Some(connection_name) = connection {
178+
(
179+
path.to_owned(),
180+
format!(", connection => '{}'", connection_name),
181181
)
182182
} else {
183-
let mut path = path.to_owned();
184-
if path.starts_with('/') {
185-
path = format!("fs://{}", path);
183+
let mut p = path.to_owned();
184+
if p.starts_with('/') {
185+
p = format!("fs://{}", p);
186186
}
187-
188-
if !path.contains("://") {
189-
path = format!(
187+
if !p.contains("://") {
188+
p = format!(
190189
"fs://{}/{}",
191190
std::env::current_dir().unwrap().to_str().unwrap(),
192-
path.as_str()
191+
p.as_str()
193192
);
194193
}
194+
(p, String::new())
195+
};
195196

196-
let pattern_clause = pattern
197-
.map(|p| format!(", pattern => '{}'", p))
198-
.unwrap_or_default();
199-
format!(
200-
"create view {} as select * from '{}' (file_format => '{}'{})",
201-
name, path, file_format, pattern_clause
202-
)
197+
let pattern_clause = pattern
198+
.map(|p| format!(", pattern => '{}'", p))
199+
.unwrap_or_default();
200+
201+
// For CSV/TSV, use infer_schema to get column positions instead of SELECT *
202+
let select_clause = if file_format == "csv" || file_format == "tsv" {
203+
let col_names =
204+
self.infer_column_names(&file_path, file_format, connection, py)?;
205+
if col_names.is_empty() {
206+
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
207+
"Could not infer schema from CSV/TSV file: no columns found",
208+
));
209+
}
210+
col_names
211+
.iter()
212+
.enumerate()
213+
.map(|(i, col_name)| format!("${} AS `{}`", i + 1, col_name))
214+
.collect::<Vec<_>>()
215+
.join(", ")
216+
} else {
217+
"*".to_string()
203218
};
204219

220+
let sql = format!(
221+
"create view {} as select {} from '{}' (file_format => '{}'{}{})",
222+
name, select_clause, file_path, file_format, pattern_clause, connection_clause
223+
);
224+
205225
let _ = self.sql(&sql, py)?.collect(py)?;
206226
Ok(())
207227
}
208228

229+
fn infer_column_names(
230+
&mut self,
231+
file_path: &str,
232+
file_format: &str,
233+
connection: Option<&str>,
234+
py: Python,
235+
) -> PyResult<Vec<String>> {
236+
let connection_clause = connection
237+
.map(|c| format!(", connection_name => '{}'", c))
238+
.unwrap_or_default();
239+
240+
let infer_sql = format!(
241+
"SELECT column_name FROM infer_schema(location => '{}', file_format => '{}'{})",
242+
file_path,
243+
file_format.to_uppercase(),
244+
connection_clause
245+
);
246+
247+
let df = self.sql(&infer_sql, py)?;
248+
let blocks = df.collect(py)?;
249+
250+
let mut col_names = Vec::new();
251+
for block in &blocks.blocks {
252+
if block.num_rows() == 0 {
253+
continue;
254+
}
255+
let entry = block.get_by_offset(0);
256+
match entry {
257+
BlockEntry::Column(Column::String(col)) => {
258+
for val in col.iter() {
259+
col_names.push(val.to_string());
260+
}
261+
}
262+
BlockEntry::Column(Column::Nullable(nullable_col)) => {
263+
if let Column::String(col) = &nullable_col.column {
264+
for val in col.iter() {
265+
col_names.push(val.to_string());
266+
}
267+
}
268+
}
269+
_ => {}
270+
}
271+
}
272+
273+
Ok(col_names)
274+
}
275+
209276
#[pyo3(signature = (name, access_key_id, secret_access_key, endpoint_url = None, region = None))]
210277
fn create_s3_connection(
211278
&mut self,

src/bendpy/tests/test_basic.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from databend import SessionContext
1717
import pandas as pd
1818
import polars
19+
import tempfile
20+
import os
1921

2022

2123
class TestBasic:
@@ -60,3 +62,32 @@ def test_create_insert_select(self):
6062
"select sum(a) x, max(b) y, max(d) z from aa where c"
6163
).to_polars()
6264
assert df.to_pandas().values.tolist() == [[90.0, "9", 9.0]]
65+
66+
def test_register_csv(self):
67+
with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
68+
f.write("name,age,city\n")
69+
f.write("Alice,30,NYC\n")
70+
f.write("Bob,25,LA\n")
71+
f.write("Charlie,35,Chicago\n")
72+
csv_path = f.name
73+
74+
try:
75+
self.ctx.register_csv("people", csv_path)
76+
df = self.ctx.sql("SELECT name, age, city FROM people ORDER BY age").to_pandas()
77+
assert df.values.tolist() == [["Bob", "25", "LA"], ["Alice", "30", "NYC"], ["Charlie", "35", "Chicago"]]
78+
finally:
79+
os.unlink(csv_path)
80+
81+
def test_register_tsv(self):
82+
with tempfile.NamedTemporaryFile(mode="w", suffix=".tsv", delete=False) as f:
83+
f.write("id\tvalue\n")
84+
f.write("1\thello\n")
85+
f.write("2\tworld\n")
86+
tsv_path = f.name
87+
88+
try:
89+
self.ctx.register_tsv("items", tsv_path)
90+
df = self.ctx.sql("SELECT id, value FROM items ORDER BY id").to_pandas()
91+
assert df.values.tolist() == [["1", "hello"], ["2", "world"]]
92+
finally:
93+
os.unlink(tsv_path)

src/bendpy/tests/test_connections.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,27 @@ def register_parquet(self, name, path, pattern=None, connection=None):
4646
def register_csv(self, name, path, pattern=None, connection=None):
4747
if connection:
4848
pattern_clause = f", pattern => '{pattern}'" if pattern else ""
49-
sql = f"create view {name} as select * from '{path}' (file_format => 'csv'{pattern_clause}, connection => '{connection}')"
49+
connection_clause = f", connection => '{connection}'"
50+
# Infer schema first for CSV
51+
infer_conn = f", connection_name => '{connection}'"
52+
self.sql(
53+
f"SELECT column_name FROM infer_schema(location => '{path}', file_format => 'CSV'{infer_conn})"
54+
)
55+
# Use column positions from infer_schema (simulated as 3 columns)
56+
select_clause = "$1 AS `col1`, $2 AS `col2`, $3 AS `col3`"
57+
sql = f"create view {name} as select {select_clause} from '{path}' (file_format => 'csv'{pattern_clause}{connection_clause})"
5058
else:
59+
p = path
60+
if p.startswith("/"):
61+
p = f"fs://{p}"
5162
pattern_clause = f", pattern => '{pattern}'" if pattern else ""
52-
sql = f"create view {name} as select * from '{path}' (file_format => 'csv'{pattern_clause})"
63+
# Infer schema first for CSV
64+
self.sql(
65+
f"SELECT column_name FROM infer_schema(location => '{p}', file_format => 'CSV')"
66+
)
67+
# Use column positions from infer_schema (simulated as 3 columns)
68+
select_clause = "$1 AS `col1`, $2 AS `col2`, $3 AS `col3`"
69+
sql = f"create view {name} as select {select_clause} from '{p}' (file_format => 'csv'{pattern_clause})"
5370
self.sql(sql)
5471

5572
def create_azblob_connection(self, name, endpoint_url, account_name, account_key):
@@ -253,8 +270,15 @@ def test_register_csv_with_connection(self):
253270

254271
self.ctx.register_csv("users", "s3://bucket/users.csv", connection="my_s3")
255272

256-
expected_sql = "create view users as select * from 's3://bucket/users.csv' (file_format => 'csv', connection => 'my_s3')"
257-
mock_sql.assert_called_once_with(expected_sql)
273+
assert mock_sql.call_count == 2
274+
# First call: infer_schema
275+
mock_sql.assert_any_call(
276+
"SELECT column_name FROM infer_schema(location => 's3://bucket/users.csv', file_format => 'CSV', connection_name => 'my_s3')"
277+
)
278+
# Second call: create view with column positions
279+
mock_sql.assert_any_call(
280+
"create view users as select $1 AS `col1`, $2 AS `col2`, $3 AS `col3` from 's3://bucket/users.csv' (file_format => 'csv', connection => 'my_s3')"
281+
)
258282

259283
def test_register_parquet_legacy_mode(self):
260284
with unittest.mock.patch.object(self.ctx, "sql") as mock_sql:
@@ -271,8 +295,15 @@ def test_register_csv_with_pattern_no_connection(self):
271295

272296
self.ctx.register_csv("logs", "/data/logs/", pattern="*.csv")
273297

274-
expected_sql = "create view logs as select * from '/data/logs/' (file_format => 'csv', pattern => '*.csv')"
275-
mock_sql.assert_called_once_with(expected_sql)
298+
assert mock_sql.call_count == 2
299+
# First call: infer_schema with fs:// prefix
300+
mock_sql.assert_any_call(
301+
"SELECT column_name FROM infer_schema(location => 'fs:///data/logs/', file_format => 'CSV')"
302+
)
303+
# Second call: create view with column positions
304+
mock_sql.assert_any_call(
305+
"create view logs as select $1 AS `col1`, $2 AS `col2`, $3 AS `col3` from 'fs:///data/logs/' (file_format => 'csv', pattern => '*.csv')"
306+
)
276307

277308

278309
class TestStages:

0 commit comments

Comments
 (0)