Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
467 changes: 79 additions & 388 deletions Cargo.lock

Large diffs are not rendered by default.

14 changes: 4 additions & 10 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,9 @@ arrow-flight = { version = "55", features = ["flight-sql-experimental"] }
async-stream = "0.3"
bytes = "1.5"
clap = { version = "4.4", features = ["derive"] }
datafusion = { git = "https://github.com/apache/datafusion", branch = "main", features = [
"pyarrow",
"avro",
] }
datafusion-proto = { git = "https://github.com/apache/datafusion", branch = "main" }
datafusion-substrait = { git = "https://github.com/apache/datafusion", branch = "main" }
datafusion = "48.0.0"
datafusion-proto = "48.0.0"
datafusion-substrait = "48.0.0"
env_logger = "0.11"
futures = "0.3"
itertools = "0.14"
Expand All @@ -54,9 +51,6 @@ log = "0.4"
rand = "0.8"
uuid = { version = "1.6", features = ["v4"] }

serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

object_store = { version = "0.12.0", features = [
"aws",
"gcp",
Expand Down Expand Up @@ -99,4 +93,4 @@ tonic-build = { version = "0.12", default-features = false, features = [
url = "2"

[dev-dependencies]
tempfile = "3.20"
tempfile = "3.20"
8 changes: 4 additions & 4 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ fn main() -> Result<(), String> {

// We don't include the proto files in releases so that downstreams
// do not need to have PROTOC included
if Path::new("src/proto/datafusion_ray.proto").exists() {
if Path::new("src/proto/distributed_datafusion.proto").exists() {
println!("cargo:rerun-if-changed=src/proto/datafusion_common.proto");
println!("cargo:rerun-if-changed=src/proto/datafusion.proto");
println!("cargo:rerun-if-changed=src/proto/datafusion_ray.proto");
println!("cargo:rerun-if-changed=src/proto/distributed_datafusion.proto");
tonic_build::configure()
.extern_path(".datafusion", "::datafusion_proto::protobuf")
.extern_path(".datafusion_common", "::datafusion_proto::protobuf")
.compile_protos(&["src/proto/datafusion_ray.proto"], &["src/proto"])
.compile_protos(&["src/proto/distributed_datafusion.proto"], &["src/proto"])
.map_err(|e| format!("protobuf compilation failed: {e}"))?;
let generated_source_path = out.join("datafusion_ray.protobuf.rs");
let generated_source_path = out.join("distributed_datafusion.protobuf.rs");
let code = std::fs::read_to_string(generated_source_path).unwrap();
let mut file = std::fs::OpenOptions::new()
.write(true)
Expand Down
185 changes: 35 additions & 150 deletions scripts/launch_python_arrowflightsql_client.sh
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ DEFAULT_QUERY_PATH="./tpch/queries/"
# Parse command line arguments
for arg in "$@"; do
case $arg in
query_path=*)
TPCH_QUERY_PATH="${arg#*=}"
;;
*)
echo "Usage: $0 [query_path=/path/to/queries/]"
echo "Example: $0 query_path=./tpch/queries/"
echo "If no argument is provided, default path will be used: $DEFAULT_QUERY_PATH"
exit 1
;;
query_path=*)
TPCH_QUERY_PATH="${arg#*=}"
;;
*)
echo "Usage: $0 [query_path=/path/to/queries/]"
echo "Example: $0 query_path=./tpch/queries/"
echo "If no argument is provided, default path will be used: $DEFAULT_QUERY_PATH"
exit 1
;;
esac
done

Expand Down Expand Up @@ -103,15 +103,16 @@ fi
source .venv/bin/activate

# Install required packages if not already installed
if ! pip show adbc_driver_flightsql > /dev/null 2>&1 || ! pip show duckdb > /dev/null 2>&1; then
if ! pip show adbc_driver_flightsql >/dev/null 2>&1 || ! pip show rich >/dev/null 2>&1; then
echo "Installing required packages..."
pip install adbc_driver_manager adbc_driver_flightsql duckdb pyarrow
pip install adbc_driver_manager adbc_driver_flightsql rich pyarrow
fi

# Create a Python startup script
cat > .python_startup.py << 'EOF'
cat >.python_startup.py <<'EOF'
import adbc_driver_flightsql.dbapi as dbapi
import duckdb
from rich.console import Console
from rich.table import Table
import os
import sys

Expand Down Expand Up @@ -170,53 +171,25 @@ def run_sql(sql_query):
"""
try:
cur.execute(sql_query)
reader = cur.fetch_record_batch()
table = cur.fetch_arrow_table()

rich_table = Table(show_header=True, header_style="bold magenta")

# Add columns based on the PyArrow Table schema
for field in table.schema:
rich_table.add_column(field.name)

# Add rows from the PyArrow Table
for row_index in range(table.num_rows):
row_data = [str(table.column(col_index)[row_index].as_py()) for col_index in range(table.num_columns)]
rich_table.add_row(*row_data)

console = Console()
console.print(rich_table, markup=False)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 Showing table format is a lot better.


# Use basic DuckDB show() - for full output use run_sql_full() or run_sql_raw()
# duckdb.sql("select * from reader").show(max_width=10000)
duckdb.sql("select * from reader").show()
except Exception as e:
print(f"Error executing SQL query: {str(e)}")

def format_plan(plan_text):
"""
Format the plan text by replacing \n with proper indentation
"""
if not plan_text:
return ""

# If the plan_text doesn't contain \\n, return as is (it's likely already formatted)
if '\\n' not in plan_text:
return plan_text

# Split the plan into lines and add proper indentation
lines = plan_text.split('\\n')
formatted_lines = []
indent_level = 0
indent_size = 2

for line in lines:
# Skip empty lines
if not line.strip():
continue

# Count opening and closing parentheses to adjust indentation
open_parens = line.count('(')
close_parens = line.count(')')

# Adjust indent level based on parentheses
if close_parens > open_parens:
indent_level = max(0, indent_level - (close_parens - open_parens))

# Add the line with current indentation
formatted_lines.append(' ' * (indent_level * indent_size) + line.strip())

# Update indent level for next line
if open_parens > close_parens:
indent_level += (open_parens - close_parens)

return '\n'.join(formatted_lines)

def explain_query(query_name):
"""
Run EXPLAIN for a TPC-H query by name (e.g., 'q5' for EXPLAIN q5.sql)
Expand All @@ -225,64 +198,12 @@ def explain_query(query_name):
query_file = os.path.join(tpch_query_path, f"{query_name}.sql")
try:
with open(query_file, 'r') as f:
sql = f.read()
print(f"Executing EXPLAIN for query from {query_file}...")
explain_sql(sql)
sql = "explain " + f.read()
run_sql(sql)
except FileNotFoundError:
print(f"Error: Query file {query_file} not found")
except Exception as e:
print(f"Error executing EXPLAIN: {str(e)}")
import traceback
traceback.print_exc()

def explain_sql(sql_query):
"""
Run EXPLAIN on a given SQL query (passed as a string or variable) and display the formatted output.
"""
try:
print("Executing EXPLAIN...")
cur.execute(f"EXPLAIN {sql_query}")
results = cur.fetchall()
if not results:
print("No explain plan returned")
return

logical_plan = None
physical_plan = None
distributed_plan = None
distributed_stages = None
for row in results:
if row[0] == 'logical_plan':
logical_plan = row[1]
elif row[0] == 'physical_plan':
physical_plan = row[1]
elif row[0] == 'distributed_plan':
distributed_plan = row[1]
elif row[0] == 'distributed_stages':
distributed_stages = row[1]
formatted_logical = format_plan(logical_plan) if logical_plan else "Logical plan not available"
formatted_physical = format_plan(physical_plan) if physical_plan else "Physical plan not available"
formatted_distributed = format_plan(distributed_plan) if distributed_plan else "Distributed plan not available"
formatted_distributed_stages = format_plan(distributed_stages) if distributed_stages else "Distributed stages not available"
print("\nExecution Plan:")
print("=" * 100)
print("Logical Plan:")
print("-" * 100)
print(formatted_logical)
print("\nPhysical Plan:")
print("-" * 100)
print(formatted_physical)
if distributed_plan:
print("\nDistributed Plan:")
print("-" * 100)
print(formatted_distributed)
if distributed_stages:
print("\nDistributed Stages:")
print("-" * 100)
print(formatted_distributed_stages)
print("=" * 100)
except Exception as e:
print(f"Error executing EXPLAIN: {e}")
print(f"Error executing query: {str(e)}")

def explain_analyze_query(query_name):
"""
Expand All @@ -292,47 +213,15 @@ def explain_analyze_query(query_name):
query_file = os.path.join(tpch_query_path, f"{query_name}.sql")
try:
with open(query_file, 'r') as f:
sql = f.read()
print(f"Executing EXPLAIN ANALYZE for query from {query_file}...")
explain_analyze_sql(sql)
sql = "explain analyze " + f.read()
run_sql(sql)
except FileNotFoundError:
print(f"Error: Query file {query_file} not found")
except Exception as e:
print(f"Error executing EXPLAIN ANALYZE: {str(e)}")
import traceback
traceback.print_exc()

def explain_analyze_sql(sql_query):
"""
Run EXPLAIN ANALYZE on a given SQL query (passed as a string or variable) and display the formatted output with execution statistics.
"""
try:
print("Executing EXPLAIN ANALYZE...")
cur.execute(f"EXPLAIN ANALYZE {sql_query}")
results = cur.fetchall()
if not results:
print("No explain analyze plan returned")
return

plan_with_metrics = None

for row in results:
if row[0] == 'Plan with Metrics':
plan_with_metrics = row[1]

if plan_with_metrics:
# EXPLAIN ANALYZE returns execution statistics
formatted_plan_with_metrics = format_plan(plan_with_metrics)
print("\nExecution Plan with Analysis:")
print("=" * 100)
print("Physical Plan with Execution Statistics:")
print("-" * 100)
print(formatted_plan_with_metrics)
print("=" * 100)
except Exception as e:
print(f"Error executing EXPLAIN ANALYZE: {e}")

# Add the run_query function to the global namespace
sys.ps1 = ">>> "
print("\nWelcome to the TPC-H Query Client!")
print("Available commands:")
Expand All @@ -341,11 +230,7 @@ print(" show_query('q5') # Show SQL content of query 5")
print(" run_query('q5') # Run TPC-H query 5")
print(" explain_query('q5') # Show EXPLAIN plan for TPC-H query 5")
print(" explain_analyze_query('q5') # Show EXPLAIN ANALYZE plan for TPC-H query 5")
print(" # Note: EXPLAIN ANALYZE does not work with distributed plans yet")
print(" run_sql('select * from nation') # Run SQL query")
print(" explain_sql('select * from nation') # Show EXPLAIN plan for SQL query")
print(" explain_analyze_sql('select * from nation') # Show EXPLAIN ANALYZE plan for SQL query")
print(" # etc...")
print("\nConnected to database at grpc://localhost:20200")
print("Type 'exit()' to quit\n")
EOF
Expand All @@ -354,4 +239,4 @@ EOF
export TPCH_QUERY_PATH=$TPCH_QUERY_PATH

# Start Python with the startup script
python3 -i .python_startup.py
python3 -i .python_startup.py
Loading
Loading