Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 88 additions & 21 deletions src/bendpy/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
use std::sync::Arc;

use databend_common_exception::Result;
use databend_common_expression::BlockEntry;
use databend_common_expression::Column;
use databend_common_meta_app::principal::BUILTIN_ROLE_ACCOUNT_ADMIN;
use databend_common_version::BUILD_INFO;
use databend_query::sessions::BuildInfoRef;
Expand Down Expand Up @@ -171,41 +173,106 @@ impl PySessionContext {
connection: Option<&str>,
py: Python,
) -> PyResult<()> {
let sql = if let Some(connection_name) = connection {
let pattern_clause = pattern
.map(|p| format!(", pattern => '{}'", p))
.unwrap_or_default();
format!(
"create view {} as select * from '{}' (file_format => '{}'{}, connection => '{}')",
name, path, file_format, pattern_clause, connection_name
// Resolve file path
let (file_path, connection_clause) = if let Some(connection_name) = connection {
(
path.to_owned(),
format!(", connection => '{}'", connection_name),
)
} else {
let mut path = path.to_owned();
if path.starts_with('/') {
path = format!("fs://{}", path);
let mut p = path.to_owned();
if p.starts_with('/') {
p = format!("fs://{}", p);
}

if !path.contains("://") {
path = format!(
if !p.contains("://") {
p = format!(
"fs://{}/{}",
std::env::current_dir().unwrap().to_str().unwrap(),
path.as_str()
p.as_str()
);
}
(p, String::new())
};

let pattern_clause = pattern
.map(|p| format!(", pattern => '{}'", p))
.unwrap_or_default();
format!(
"create view {} as select * from '{}' (file_format => '{}'{})",
name, path, file_format, pattern_clause
)
let pattern_clause = pattern
.map(|p| format!(", pattern => '{}'", p))
.unwrap_or_default();

// For CSV/TSV, use infer_schema to get column positions instead of SELECT *
let select_clause = if file_format == "csv" || file_format == "tsv" {
let col_names =
self.infer_column_names(&file_path, file_format, connection, py)?;
if col_names.is_empty() {
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"Could not infer schema from CSV/TSV file: no columns found",
));
}
col_names
.iter()
.enumerate()
.map(|(i, col_name)| format!("${} AS `{}`", i + 1, col_name))
.collect::<Vec<_>>()
.join(", ")
} else {
"*".to_string()
};

let sql = format!(
"create view {} as select {} from '{}' (file_format => '{}'{}{})",
name, select_clause, file_path, file_format, pattern_clause, connection_clause
);

let _ = self.sql(&sql, py)?.collect(py)?;
Ok(())
}

/// Infers the column names of a staged CSV/TSV file by running the
/// `infer_schema` table function and collecting the `column_name` column.
///
/// Builds and executes
/// `SELECT column_name FROM infer_schema(location => …, file_format => …)`
/// through `self.sql`, then walks every returned block and extracts the
/// string values from the first column (offset 0).
///
/// Returns the column names in the order the query produced them; the vec
/// is empty when no rows came back or the column had an unexpected type.
///
/// # Errors
/// Propagates any `PyErr` raised while planning, executing, or collecting
/// the query.
fn infer_column_names(
&mut self,
file_path: &str,
file_format: &str,
connection: Option<&str>,
py: Python,
) -> PyResult<Vec<String>> {
// NOTE: infer_schema takes `connection_name =>`, unlike the stage-read
// syntax used elsewhere, which takes `connection =>`.
let connection_clause = connection
.map(|c| format!(", connection_name => '{}'", c))
.unwrap_or_default();

// infer_schema expects the format name in upper case (e.g. 'CSV', 'TSV').
let infer_sql = format!(
"SELECT column_name FROM infer_schema(location => '{}', file_format => '{}'{})",
file_path,
file_format.to_uppercase(),
connection_clause
);

let df = self.sql(&infer_sql, py)?;
let blocks = df.collect(py)?;

let mut col_names = Vec::new();
for block in &blocks.blocks {
// Skip empty blocks; they contribute no names.
if block.num_rows() == 0 {
continue;
}
// The query selects a single column, so the names live at offset 0.
let entry = block.get_by_offset(0);
match entry {
// Plain string column: take every value as a column name.
BlockEntry::Column(Column::String(col)) => {
for val in col.iter() {
col_names.push(val.to_string());
}
}
// Nullable wrapper around a string column: read the inner column.
// NOTE(review): null slots are not filtered here via the validity
// bitmap — presumably column_name is never NULL; confirm upstream.
BlockEntry::Column(Column::Nullable(nullable_col)) => {
if let Column::String(col) = &nullable_col.column {
for val in col.iter() {
col_names.push(val.to_string());
}
}
}
// Any other column type is silently ignored (yields no names).
_ => {}
}
}

Ok(col_names)
}

#[pyo3(signature = (name, access_key_id, secret_access_key, endpoint_url = None, region = None))]
fn create_s3_connection(
&mut self,
Expand Down
31 changes: 31 additions & 0 deletions src/bendpy/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from databend import SessionContext
import pandas as pd
import polars
import tempfile
import os


class TestBasic:
Expand Down Expand Up @@ -60,3 +62,32 @@ def test_create_insert_select(self):
"select sum(a) x, max(b) y, max(d) z from aa where c"
).to_polars()
assert df.to_pandas().values.tolist() == [[90.0, "9", 9.0]]

def test_register_csv(self):
    # Round-trip a small CSV file through register_csv and verify the
    # resulting view: values come back as strings, ordered by the age column.
    rows = [
        "name,age,city",
        "Alice,30,NYC",
        "Bob,25,LA",
        "Charlie,35,Chicago",
    ]
    with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as tmp:
        tmp.write("\n".join(rows) + "\n")
        csv_file = tmp.name

    try:
        self.ctx.register_csv("people", csv_file)
        result = self.ctx.sql("SELECT name, age, city FROM people ORDER BY age").to_pandas()
        expected = [
            ["Bob", "25", "LA"],
            ["Alice", "30", "NYC"],
            ["Charlie", "35", "Chicago"],
        ]
        assert result.values.tolist() == expected
    finally:
        os.unlink(csv_file)

def test_register_tsv(self):
    # Round-trip a small TSV file through register_tsv and verify the
    # resulting view: values come back as strings, ordered by the id column.
    lines = ["id\tvalue", "1\thello", "2\tworld"]
    with tempfile.NamedTemporaryFile(mode="w", suffix=".tsv", delete=False) as tmp:
        tmp.write("\n".join(lines) + "\n")
        tsv_file = tmp.name

    try:
        self.ctx.register_tsv("items", tsv_file)
        result = self.ctx.sql("SELECT id, value FROM items ORDER BY id").to_pandas()
        assert result.values.tolist() == [["1", "hello"], ["2", "world"]]
    finally:
        os.unlink(tsv_file)
43 changes: 37 additions & 6 deletions src/bendpy/tests/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,27 @@ def register_parquet(self, name, path, pattern=None, connection=None):
def register_csv(self, name, path, pattern=None, connection=None):
    """Register a CSV location as a view named ``name``.

    Mirrors the native implementation: first runs ``infer_schema`` to
    discover the columns, then creates the view selecting positional
    columns (``$1``, ``$2``, …) instead of ``SELECT *``.

    Fix: the previous version assigned ``sql`` twice in each branch
    (stale ``select *`` statements left over from the old code path);
    the dead first assignments — one of which used the un-normalized
    ``path`` instead of the ``fs://``-prefixed ``p`` — are removed.
    """
    if connection:
        pattern_clause = f", pattern => '{pattern}'" if pattern else ""
        connection_clause = f", connection => '{connection}'"
        # Infer schema first for CSV; infer_schema uses `connection_name =>`.
        infer_conn = f", connection_name => '{connection}'"
        self.sql(
            f"SELECT column_name FROM infer_schema(location => '{path}', file_format => 'CSV'{infer_conn})"
        )
        # Use column positions from infer_schema (simulated as 3 columns)
        select_clause = "$1 AS `col1`, $2 AS `col2`, $3 AS `col3`"
        sql = f"create view {name} as select {select_clause} from '{path}' (file_format => 'csv'{pattern_clause}{connection_clause})"
    else:
        # Local paths are normalized to fs:// URLs before use.
        p = path
        if p.startswith("/"):
            p = f"fs://{p}"
        pattern_clause = f", pattern => '{pattern}'" if pattern else ""
        # Infer schema first for CSV
        self.sql(
            f"SELECT column_name FROM infer_schema(location => '{p}', file_format => 'CSV')"
        )
        # Use column positions from infer_schema (simulated as 3 columns)
        select_clause = "$1 AS `col1`, $2 AS `col2`, $3 AS `col3`"
        sql = f"create view {name} as select {select_clause} from '{p}' (file_format => 'csv'{pattern_clause})"
    self.sql(sql)

def create_azblob_connection(self, name, endpoint_url, account_name, account_key):
Expand Down Expand Up @@ -253,8 +270,15 @@ def test_register_csv_with_connection(self):

self.ctx.register_csv("users", "s3://bucket/users.csv", connection="my_s3")

expected_sql = "create view users as select * from 's3://bucket/users.csv' (file_format => 'csv', connection => 'my_s3')"
mock_sql.assert_called_once_with(expected_sql)
assert mock_sql.call_count == 2
# First call: infer_schema
mock_sql.assert_any_call(
"SELECT column_name FROM infer_schema(location => 's3://bucket/users.csv', file_format => 'CSV', connection_name => 'my_s3')"
)
# Second call: create view with column positions
mock_sql.assert_any_call(
"create view users as select $1 AS `col1`, $2 AS `col2`, $3 AS `col3` from 's3://bucket/users.csv' (file_format => 'csv', connection => 'my_s3')"
)

def test_register_parquet_legacy_mode(self):
with unittest.mock.patch.object(self.ctx, "sql") as mock_sql:
Expand All @@ -271,8 +295,15 @@ def test_register_csv_with_pattern_no_connection(self):

self.ctx.register_csv("logs", "/data/logs/", pattern="*.csv")

expected_sql = "create view logs as select * from '/data/logs/' (file_format => 'csv', pattern => '*.csv')"
mock_sql.assert_called_once_with(expected_sql)
assert mock_sql.call_count == 2
# First call: infer_schema with fs:// prefix
mock_sql.assert_any_call(
"SELECT column_name FROM infer_schema(location => 'fs:///data/logs/', file_format => 'CSV')"
)
# Second call: create view with column positions
mock_sql.assert_any_call(
"create view logs as select $1 AS `col1`, $2 AS `col2`, $3 AS `col3` from 'fs:///data/logs/' (file_format => 'csv', pattern => '*.csv')"
)


class TestStages:
Expand Down
Loading