117 changes: 88 additions & 29 deletions src/bendpy/src/context.rs
@@ -15,6 +15,8 @@
use std::sync::Arc;

use databend_common_exception::Result;
use databend_common_expression::BlockEntry;
use databend_common_expression::Column;
use databend_common_meta_app::principal::BUILTIN_ROLE_ACCOUNT_ADMIN;
use databend_common_version::BUILD_INFO;
use databend_query::sessions::BuildInfoRef;
@@ -28,6 +30,33 @@ use crate::dataframe::default_box_size;
use crate::utils::RUNTIME;
use crate::utils::wait_for_future;

/// Turn a user-supplied path into a URI: explicit schemes pass through,
/// absolute paths get an `fs://` prefix, and relative paths are anchored
/// at the current working directory.
fn resolve_file_path(path: &str) -> String {
if path.contains("://") {
return path.to_owned();
}
if path.starts_with('/') {
return format!("fs://{}", path);
}
format!(
"fs://{}/{}",
std::env::current_dir().unwrap().to_str().unwrap(),
path
)
}
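
For illustration, the mapping resolve_file_path is meant to produce (a sketch, assuming a hypothetical working directory of /home/user):

    // Sketch only; /home/user stands in for std::env::current_dir().
    assert_eq!(resolve_file_path("s3://bucket/data.csv"), "s3://bucket/data.csv"); // scheme kept as-is
    assert_eq!(resolve_file_path("/data/a.csv"), "fs:///data/a.csv");              // absolute path
    assert_eq!(resolve_file_path("rel/a.csv"), "fs:///home/user/rel/a.csv");       // relative, anchored at cwd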

/// Extract the underlying `StringColumn` from a block entry, looking through
/// one level of `Nullable` wrapping; returns `None` for non-string columns.
fn extract_string_column(
    entry: &BlockEntry,
) -> Option<&databend_common_expression::types::StringColumn> {
match entry {
BlockEntry::Column(Column::String(col)) => Some(col),
BlockEntry::Column(Column::Nullable(n)) => match &n.column {
Column::String(col) => Some(col),
_ => None,
},
_ => None,
}
}

#[pyclass(name = "SessionContext", module = "databend", subclass)]
#[derive(Clone)]
pub(crate) struct PySessionContext {
@@ -171,41 +200,71 @@ impl PySessionContext
connection: Option<&str>,
py: Python,
) -> PyResult<()> {
let sql = if let Some(connection_name) = connection {
let pattern_clause = pattern
.map(|p| format!(", pattern => '{}'", p))
.unwrap_or_default();
format!(
"create view {} as select * from '{}' (file_format => '{}'{}, connection => '{}')",
name, path, file_format, pattern_clause, connection_name
)
} else {
let mut path = path.to_owned();
if path.starts_with('/') {
path = format!("fs://{}", path);
}

if !path.contains("://") {
path = format!(
"fs://{}/{}",
std::env::current_dir().unwrap().to_str().unwrap(),
path.as_str()
);
}

let pattern_clause = pattern
.map(|p| format!(", pattern => '{}'", p))
.unwrap_or_default();
format!(
"create view {} as select * from '{}' (file_format => '{}'{})",
name, path, file_format, pattern_clause
)
let file_path = match connection {
Some(_) => path.to_owned(),
None => resolve_file_path(path),
};
let connection_clause = connection
.map(|c| format!(", connection => '{}'", c))
.unwrap_or_default();
let pattern_clause = pattern
.map(|p| format!(", pattern => '{}'", p))
.unwrap_or_default();

let select_clause = match file_format {
"csv" | "tsv" => self.build_column_select(&file_path, file_format, connection, py)?,
_ => "*".to_string(),
Comment on lines 214 to 218

P1: Skip TSV schema inference until infer_schema supports TSV

register_tsv() now routes through build_column_select, which emits infer_schema(..., file_format => 'TSV'); however the infer-schema table function currently rejects TSV formats (it only accepts Parquet/CSV/NDJSON in src/query/service/src/table_functions/infer_schema/infer_schema_table.rs), so TSV registration still fails on every call. This means the patch does not actually fix the TSV path and users still cannot register TSV files.
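
A minimal fix along the reviewer's lines (a sketch under that premise, not part of this patch) would gate inference on the formats infer_schema accepts today:

    // Sketch: route only CSV through inference; TSV keeps the `*` fallback
    // until infer_schema gains TSV support.
    let select_clause = match file_format {
        "csv" => self.build_column_select(&file_path, file_format, connection, py)?,
        _ => "*".to_string(),
    };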


};

let sql = format!(
"create view {} as select {} from '{}' (file_format => '{}'{}{})",
name, select_clause, file_path, file_format, pattern_clause, connection_clause
);
let _ = self.sql(&sql, py)?.collect(py)?;
Ok(())
}

/// Infer column names via `infer_schema` and build `$1 AS col1, $2 AS col2, ...`.
fn build_column_select(
&mut self,
file_path: &str,
file_format: &str,
connection: Option<&str>,
py: Python,
) -> PyResult<String> {
let conn_clause = connection
.map(|c| format!(", connection_name => '{}'", c))
.unwrap_or_default();
let sql = format!(
"SELECT column_name FROM infer_schema(location => '{}', file_format => '{}'{})",
file_path,
file_format.to_uppercase(),
conn_clause
);

let blocks = self.sql(&sql, py)?.collect(py)?;
let col_names: Vec<String> = blocks
.blocks
.iter()
.filter(|b| b.num_rows() > 0)
.filter_map(|b| extract_string_column(b.get_by_offset(0)))
.flat_map(|col| col.iter().map(|s| s.to_string()))
.collect();

if col_names.is_empty() {
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"Could not infer schema: no columns found",
));
}

Ok(col_names
.iter()
.enumerate()
.map(|(i, name)| format!("${} AS `{}`", i + 1, name))
.collect::<Vec<_>>()
.join(", "))
}
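
As a worked example with hypothetical inferred columns name, age, and city, the positional mapping this function returns can be reproduced standalone:

    // Sketch of the `$n AS name` mapping, using hypothetical column names.
    let cols = ["name", "age", "city"];
    let select = cols
        .iter()
        .enumerate()
        .map(|(i, c)| format!("${} AS `{}`", i + 1, c))
        .collect::<Vec<_>>()
        .join(", ");
    assert_eq!(select, "$1 AS `name`, $2 AS `age`, $3 AS `city`");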

#[pyo3(signature = (name, access_key_id, secret_access_key, endpoint_url = None, region = None))]
fn create_s3_connection(
&mut self,
31 changes: 31 additions & 0 deletions src/bendpy/tests/test_basic.py
@@ -16,6 +16,8 @@
from databend import SessionContext
import pandas as pd
import polars
import tempfile
import os


class TestBasic:
@@ -60,3 +62,32 @@ def test_create_insert_select(self):
"select sum(a) x, max(b) y, max(d) z from aa where c"
).to_polars()
assert df.to_pandas().values.tolist() == [[90.0, "9", 9.0]]

def test_register_csv(self):
with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
f.write("name,age,city\n")
f.write("Alice,30,NYC\n")
f.write("Bob,25,LA\n")
f.write("Charlie,35,Chicago\n")
csv_path = f.name

try:
self.ctx.register_csv("people", csv_path)
df = self.ctx.sql("SELECT name, age, city FROM people ORDER BY age").to_pandas()
assert df.values.tolist() == [["Bob", "25", "LA"], ["Alice", "30", "NYC"], ["Charlie", "35", "Chicago"]]
finally:
os.unlink(csv_path)

def test_register_tsv(self):
with tempfile.NamedTemporaryFile(mode="w", suffix=".tsv", delete=False) as f:
f.write("id\tvalue\n")
f.write("1\thello\n")
f.write("2\tworld\n")
tsv_path = f.name

try:
self.ctx.register_tsv("items", tsv_path)
df = self.ctx.sql("SELECT id, value FROM items ORDER BY id").to_pandas()
assert df.values.tolist() == [["1", "hello"], ["2", "world"]]
finally:
os.unlink(tsv_path)
45 changes: 34 additions & 11 deletions src/bendpy/tests/test_connections.py
@@ -44,13 +44,22 @@ def register_parquet(self, name, path, pattern=None, connection=None):
self.sql(sql)

def register_csv(self, name, path, pattern=None, connection=None):
if connection:
pattern_clause = f", pattern => '{pattern}'" if pattern else ""
sql = f"create view {name} as select * from '{path}' (file_format => 'csv'{pattern_clause}, connection => '{connection}')"
else:
pattern_clause = f", pattern => '{pattern}'" if pattern else ""
sql = f"create view {name} as select * from '{path}' (file_format => 'csv'{pattern_clause})"
self.sql(sql)
self._register_delimited(name, path, "csv", pattern, connection)

def _register_delimited(self, name, path, fmt, pattern=None, connection=None):
"""CSV/TSV: infer schema first, then create view with column positions."""
file_path = path if connection else (f"fs://{path}" if path.startswith("/") else path)
conn_infer = f", connection_name => '{connection}'" if connection else ""
self.sql(
f"SELECT column_name FROM infer_schema(location => '{file_path}', file_format => '{fmt.upper()}'{conn_infer})"
)
# Simulated: infer_schema returns 3 columns
select = "$1 AS `col1`, $2 AS `col2`, $3 AS `col3`"
pattern_clause = f", pattern => '{pattern}'" if pattern else ""
conn_clause = f", connection => '{connection}'" if connection else ""
self.sql(
f"create view {name} as select {select} from '{file_path}' (file_format => '{fmt}'{pattern_clause}{conn_clause})"
)

def create_azblob_connection(self, name, endpoint_url, account_name, account_key):
sql = f"CREATE OR REPLACE CONNECTION {name} STORAGE_TYPE = 'AZBLOB' endpoint_url = '{endpoint_url}' account_name = '{account_name}' account_key = '{account_key}'"
@@ -253,8 +262,15 @@ def test_register_csv_with_connection(self):

self.ctx.register_csv("users", "s3://bucket/users.csv", connection="my_s3")

expected_sql = "create view users as select * from 's3://bucket/users.csv' (file_format => 'csv', connection => 'my_s3')"
mock_sql.assert_called_once_with(expected_sql)
assert mock_sql.call_count == 2
# First call: infer_schema
mock_sql.assert_any_call(
"SELECT column_name FROM infer_schema(location => 's3://bucket/users.csv', file_format => 'CSV', connection_name => 'my_s3')"
)
# Second call: create view with column positions
mock_sql.assert_any_call(
"create view users as select $1 AS `col1`, $2 AS `col2`, $3 AS `col3` from 's3://bucket/users.csv' (file_format => 'csv', connection => 'my_s3')"
)

def test_register_parquet_legacy_mode(self):
with unittest.mock.patch.object(self.ctx, "sql") as mock_sql:
@@ -271,8 +287,15 @@ def test_register_csv_with_pattern_no_connection(self):

self.ctx.register_csv("logs", "/data/logs/", pattern="*.csv")

expected_sql = "create view logs as select * from '/data/logs/' (file_format => 'csv', pattern => '*.csv')"
mock_sql.assert_called_once_with(expected_sql)
assert mock_sql.call_count == 2
# First call: infer_schema with fs:// prefix
mock_sql.assert_any_call(
"SELECT column_name FROM infer_schema(location => 'fs:///data/logs/', file_format => 'CSV')"
)
# Second call: create view with column positions
mock_sql.assert_any_call(
"create view logs as select $1 AS `col1`, $2 AS `col2`, $3 AS `col3` from 'fs:///data/logs/' (file_format => 'csv', pattern => '*.csv')"
)


class TestStages: