Skip to content

Commit 00ac55a

Browse files
authored
Merge pull request #42 from datafusion-contrib/robtandy/explainanalyze
Explain Analyze + Refactor
2 parents eed7176 + e0db0e3 commit 00ac55a

30 files changed

+1934
-2569
lines changed

Cargo.lock

Lines changed: 79 additions & 388 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 4 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -39,12 +39,9 @@ arrow-flight = { version = "55", features = ["flight-sql-experimental"] }
3939
async-stream = "0.3"
4040
bytes = "1.5"
4141
clap = { version = "4.4", features = ["derive"] }
42-
datafusion = { git = "https://github.com/apache/datafusion", branch = "main", features = [
43-
"pyarrow",
44-
"avro",
45-
] }
46-
datafusion-proto = { git = "https://github.com/apache/datafusion", branch = "main" }
47-
datafusion-substrait = { git = "https://github.com/apache/datafusion", branch = "main" }
42+
datafusion = "48.0.0"
43+
datafusion-proto = "48.0.0"
44+
datafusion-substrait = "48.0.0"
4845
env_logger = "0.11"
4946
futures = "0.3"
5047
itertools = "0.14"
@@ -54,9 +51,6 @@ log = "0.4"
5451
rand = "0.8"
5552
uuid = { version = "1.6", features = ["v4"] }
5653

57-
serde = { version = "1.0", features = ["derive"] }
58-
serde_json = "1.0"
59-
6054
object_store = { version = "0.12.0", features = [
6155
"aws",
6256
"gcp",
@@ -99,4 +93,4 @@ tonic-build = { version = "0.12", default-features = false, features = [
9993
url = "2"
10094

10195
[dev-dependencies]
102-
tempfile = "3.20"
96+
tempfile = "3.20"

build.rs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -32,16 +32,16 @@ fn main() -> Result<(), String> {
3232

3333
// We don't include the proto files in releases so that downstreams
3434
// do not need to have PROTOC included
35-
if Path::new("src/proto/datafusion_ray.proto").exists() {
35+
if Path::new("src/proto/distributed_datafusion.proto").exists() {
3636
println!("cargo:rerun-if-changed=src/proto/datafusion_common.proto");
3737
println!("cargo:rerun-if-changed=src/proto/datafusion.proto");
38-
println!("cargo:rerun-if-changed=src/proto/datafusion_ray.proto");
38+
println!("cargo:rerun-if-changed=src/proto/distributed_datafusion.proto");
3939
tonic_build::configure()
4040
.extern_path(".datafusion", "::datafusion_proto::protobuf")
4141
.extern_path(".datafusion_common", "::datafusion_proto::protobuf")
42-
.compile_protos(&["src/proto/datafusion_ray.proto"], &["src/proto"])
42+
.compile_protos(&["src/proto/distributed_datafusion.proto"], &["src/proto"])
4343
.map_err(|e| format!("protobuf compilation failed: {e}"))?;
44-
let generated_source_path = out.join("datafusion_ray.protobuf.rs");
44+
let generated_source_path = out.join("distributed_datafusion.protobuf.rs");
4545
let code = std::fs::read_to_string(generated_source_path).unwrap();
4646
let mut file = std::fs::OpenOptions::new()
4747
.write(true)

scripts/launch_python_arrowflightsql_client.sh

Lines changed: 35 additions & 150 deletions
Original file line number | Diff line number | Diff line change
@@ -63,15 +63,15 @@ DEFAULT_QUERY_PATH="./tpch/queries/"
6363
# Parse command line arguments
6464
for arg in "$@"; do
6565
case $arg in
66-
query_path=*)
67-
TPCH_QUERY_PATH="${arg#*=}"
68-
;;
69-
*)
70-
echo "Usage: $0 [query_path=/path/to/queries/]"
71-
echo "Example: $0 query_path=./tpch/queries/"
72-
echo "If no argument is provided, default path will be used: $DEFAULT_QUERY_PATH"
73-
exit 1
74-
;;
66+
query_path=*)
67+
TPCH_QUERY_PATH="${arg#*=}"
68+
;;
69+
*)
70+
echo "Usage: $0 [query_path=/path/to/queries/]"
71+
echo "Example: $0 query_path=./tpch/queries/"
72+
echo "If no argument is provided, default path will be used: $DEFAULT_QUERY_PATH"
73+
exit 1
74+
;;
7575
esac
7676
done
7777

@@ -103,15 +103,16 @@ fi
103103
source .venv/bin/activate
104104

105105
# Install required packages if not already installed
106-
if ! pip show adbc_driver_flightsql > /dev/null 2>&1 || ! pip show duckdb > /dev/null 2>&1; then
106+
if ! pip show adbc_driver_flightsql >/dev/null 2>&1 || ! pip show rich >/dev/null 2>&1; then
107107
echo "Installing required packages..."
108-
pip install adbc_driver_manager adbc_driver_flightsql duckdb pyarrow
108+
pip install adbc_driver_manager adbc_driver_flightsql rich pyarrow
109109
fi
110110

111111
# Create a Python startup script
112-
cat > .python_startup.py << 'EOF'
112+
cat >.python_startup.py <<'EOF'
113113
import adbc_driver_flightsql.dbapi as dbapi
114-
import duckdb
114+
from rich.console import Console
115+
from rich.table import Table
115116
import os
116117
import sys
117118
@@ -170,53 +171,25 @@ def run_sql(sql_query):
170171
"""
171172
try:
172173
cur.execute(sql_query)
173-
reader = cur.fetch_record_batch()
174+
table = cur.fetch_arrow_table()
175+
176+
rich_table = Table(show_header=True, header_style="bold magenta")
177+
178+
# Add columns based on the PyArrow Table schema
179+
for field in table.schema:
180+
rich_table.add_column(field.name)
181+
182+
# Add rows from the PyArrow Table
183+
for row_index in range(table.num_rows):
184+
row_data = [str(table.column(col_index)[row_index].as_py()) for col_index in range(table.num_columns)]
185+
rich_table.add_row(*row_data)
186+
187+
console = Console()
188+
console.print(rich_table, markup=False)
174189
175-
# Use basic DuckDB show() - for full output use run_sql_full() or run_sql_raw()
176-
# duckdb.sql("select * from reader").show(max_width=10000)
177-
duckdb.sql("select * from reader").show()
178190
except Exception as e:
179191
print(f"Error executing SQL query: {str(e)}")
180192
181-
def format_plan(plan_text):
182-
"""
183-
Format the plan text by replacing \n with proper indentation
184-
"""
185-
if not plan_text:
186-
return ""
187-
188-
# If the plan_text doesn't contain \\n, return as is (it's likely already formatted)
189-
if '\\n' not in plan_text:
190-
return plan_text
191-
192-
# Split the plan into lines and add proper indentation
193-
lines = plan_text.split('\\n')
194-
formatted_lines = []
195-
indent_level = 0
196-
indent_size = 2
197-
198-
for line in lines:
199-
# Skip empty lines
200-
if not line.strip():
201-
continue
202-
203-
# Count opening and closing parentheses to adjust indentation
204-
open_parens = line.count('(')
205-
close_parens = line.count(')')
206-
207-
# Adjust indent level based on parentheses
208-
if close_parens > open_parens:
209-
indent_level = max(0, indent_level - (close_parens - open_parens))
210-
211-
# Add the line with current indentation
212-
formatted_lines.append(' ' * (indent_level * indent_size) + line.strip())
213-
214-
# Update indent level for next line
215-
if open_parens > close_parens:
216-
indent_level += (open_parens - close_parens)
217-
218-
return '\n'.join(formatted_lines)
219-
220193
def explain_query(query_name):
221194
"""
222195
Run EXPLAIN for a TPC-H query by name (e.g., 'q5' for EXPLAIN q5.sql)
@@ -225,64 +198,12 @@ def explain_query(query_name):
225198
query_file = os.path.join(tpch_query_path, f"{query_name}.sql")
226199
try:
227200
with open(query_file, 'r') as f:
228-
sql = f.read()
229-
print(f"Executing EXPLAIN for query from {query_file}...")
230-
explain_sql(sql)
201+
sql = "explain " + f.read()
202+
run_sql(sql)
231203
except FileNotFoundError:
232204
print(f"Error: Query file {query_file} not found")
233205
except Exception as e:
234-
print(f"Error executing EXPLAIN: {str(e)}")
235-
import traceback
236-
traceback.print_exc()
237-
238-
def explain_sql(sql_query):
239-
"""
240-
Run EXPLAIN on a given SQL query (passed as a string or variable) and display the formatted output.
241-
"""
242-
try:
243-
print("Executing EXPLAIN...")
244-
cur.execute(f"EXPLAIN {sql_query}")
245-
results = cur.fetchall()
246-
if not results:
247-
print("No explain plan returned")
248-
return
249-
250-
logical_plan = None
251-
physical_plan = None
252-
distributed_plan = None
253-
distributed_stages = None
254-
for row in results:
255-
if row[0] == 'logical_plan':
256-
logical_plan = row[1]
257-
elif row[0] == 'physical_plan':
258-
physical_plan = row[1]
259-
elif row[0] == 'distributed_plan':
260-
distributed_plan = row[1]
261-
elif row[0] == 'distributed_stages':
262-
distributed_stages = row[1]
263-
formatted_logical = format_plan(logical_plan) if logical_plan else "Logical plan not available"
264-
formatted_physical = format_plan(physical_plan) if physical_plan else "Physical plan not available"
265-
formatted_distributed = format_plan(distributed_plan) if distributed_plan else "Distributed plan not available"
266-
formatted_distributed_stages = format_plan(distributed_stages) if distributed_stages else "Distributed stages not available"
267-
print("\nExecution Plan:")
268-
print("=" * 100)
269-
print("Logical Plan:")
270-
print("-" * 100)
271-
print(formatted_logical)
272-
print("\nPhysical Plan:")
273-
print("-" * 100)
274-
print(formatted_physical)
275-
if distributed_plan:
276-
print("\nDistributed Plan:")
277-
print("-" * 100)
278-
print(formatted_distributed)
279-
if distributed_stages:
280-
print("\nDistributed Stages:")
281-
print("-" * 100)
282-
print(formatted_distributed_stages)
283-
print("=" * 100)
284-
except Exception as e:
285-
print(f"Error executing EXPLAIN: {e}")
206+
print(f"Error executing query: {str(e)}")
286207
287208
def explain_analyze_query(query_name):
288209
"""
@@ -292,47 +213,15 @@ def explain_analyze_query(query_name):
292213
query_file = os.path.join(tpch_query_path, f"{query_name}.sql")
293214
try:
294215
with open(query_file, 'r') as f:
295-
sql = f.read()
296-
print(f"Executing EXPLAIN ANALYZE for query from {query_file}...")
297-
explain_analyze_sql(sql)
216+
sql = "explain analyze " + f.read()
217+
run_sql(sql)
298218
except FileNotFoundError:
299219
print(f"Error: Query file {query_file} not found")
300220
except Exception as e:
301221
print(f"Error executing EXPLAIN ANALYZE: {str(e)}")
302222
import traceback
303223
traceback.print_exc()
304224
305-
def explain_analyze_sql(sql_query):
306-
"""
307-
Run EXPLAIN ANALYZE on a given SQL query (passed as a string or variable) and display the formatted output with execution statistics.
308-
"""
309-
try:
310-
print("Executing EXPLAIN ANALYZE...")
311-
cur.execute(f"EXPLAIN ANALYZE {sql_query}")
312-
results = cur.fetchall()
313-
if not results:
314-
print("No explain analyze plan returned")
315-
return
316-
317-
plan_with_metrics = None
318-
319-
for row in results:
320-
if row[0] == 'Plan with Metrics':
321-
plan_with_metrics = row[1]
322-
323-
if plan_with_metrics:
324-
# EXPLAIN ANALYZE returns execution statistics
325-
formatted_plan_with_metrics = format_plan(plan_with_metrics)
326-
print("\nExecution Plan with Analysis:")
327-
print("=" * 100)
328-
print("Physical Plan with Execution Statistics:")
329-
print("-" * 100)
330-
print(formatted_plan_with_metrics)
331-
print("=" * 100)
332-
except Exception as e:
333-
print(f"Error executing EXPLAIN ANALYZE: {e}")
334-
335-
# Add the run_query function to the global namespace
336225
sys.ps1 = ">>> "
337226
print("\nWelcome to the TPC-H Query Client!")
338227
print("Available commands:")
@@ -341,11 +230,7 @@ print(" show_query('q5') # Show SQL content of query 5")
341230
print(" run_query('q5') # Run TPC-H query 5")
342231
print(" explain_query('q5') # Show EXPLAIN plan for TPC-H query 5")
343232
print(" explain_analyze_query('q5') # Show EXPLAIN ANALYZE plan for TPC-H query 5")
344-
print(" # Note: EXPLAIN ANALYZE does not work with distributed plans yet")
345233
print(" run_sql('select * from nation') # Run SQL query")
346-
print(" explain_sql('select * from nation') # Show EXPLAIN plan for SQL query")
347-
print(" explain_analyze_sql('select * from nation') # Show EXPLAIN ANALYZE plan for SQL query")
348-
print(" # etc...")
349234
print("\nConnected to database at grpc://localhost:20200")
350235
print("Type 'exit()' to quit\n")
351236
EOF
@@ -354,4 +239,4 @@ EOF
354239
export TPCH_QUERY_PATH=$TPCH_QUERY_PATH
355240

356241
# Start Python with the startup script
357-
python3 -i .python_startup.py
242+
python3 -i .python_startup.py

0 commit comments

Comments
 (0)