Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Doc/whatsnew/3.15.rst
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,9 @@ sqlite3
details.
(Contributed by Stan Ulbrych and Łukasz Langa in :gh:`133461`.)

* Table, index, trigger, view, column, function, and schema completion on <tab>.
(Contributed by Long Tan in :gh:`136101`.)


ssl
---
Expand Down
2 changes: 1 addition & 1 deletion Lib/sqlite3/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def main(*args):
execute(con, args.sql, suppress_errors=False, theme=theme)
else:
# No SQL provided; start the REPL.
with completer():
with completer(con):
console = SqliteInteractiveConsole(con, use_color=True)
console.interact(banner, exitmsg="")
finally:
Expand Down
78 changes: 71 additions & 7 deletions Lib/sqlite3/_completer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from _sqlite3 import OperationalError
from contextlib import contextmanager

try:
Expand All @@ -10,32 +11,95 @@
_completion_matches = []


def _complete(text, state):
def _complete(con, text, state):
global _completion_matches

if state == 0:
if text.startswith('.'):
_completion_matches = [c for c in CLI_COMMANDS if c.startswith(text)]
if text.startswith("."):
_completion_matches = [
c + " " for c in CLI_COMMANDS if c.startswith(text)
]
else:
text_upper = text.upper()
_completion_matches = [c for c in SQLITE_KEYWORDS if c.startswith(text_upper)]
_completion_matches = [
c + " " for c in SQLITE_KEYWORDS if c.startswith(text_upper)
]

cursor = con.cursor()
schemata = tuple(row[1] for row
in cursor.execute("PRAGMA database_list"))
# tables, indexes, triggers, and views
# escape '_' which can appear in attached database names
select_clauses = (
f"""\
SELECT name || ' ' FROM \"{schema}\".sqlite_master
WHERE name LIKE REPLACE(:text, '_', '^_') || '%' ESCAPE '^'"""
for schema in schemata
)
_completion_matches.extend(
row[0]
for row in cursor.execute(
" UNION ".join(select_clauses), {"text": text}
)
)
# columns
try:
select_clauses = (
f"""\
SELECT pti.name || ' ' FROM "{schema}".sqlite_master AS sm
JOIN pragma_table_xinfo(sm.name,'{schema}') AS pti
WHERE sm.type='table' AND
pti.name LIKE REPLACE(:text, '_', '^_') || '%' ESCAPE '^'"""
for schema in schemata
)
_completion_matches.extend(
row[0]
for row in cursor.execute(
" UNION ".join(select_clauses), {"text": text}
)
)
except OperationalError:
# skip on SQLite<3.16.0 where pragma table-valued function is
# not supported yet
pass
# functions
try:
_completion_matches.extend(
row[0] for row in cursor.execute("""\
SELECT DISTINCT UPPER(name) || '('
FROM pragma_function_list()
WHERE name NOT IN ('->', '->>') AND
name LIKE REPLACE(:text, '_', '^_') || '%' ESCAPE '^'""",
{"text": text},
)
)
except OperationalError:
# skip on SQLite<3.30.0 where function_list is not supported yet
pass
# schemata
text_lower = text.lower()
_completion_matches.extend(c for c in schemata
if c.lower().startswith(text_lower))
_completion_matches = sorted(set(_completion_matches))
try:
return _completion_matches[state] + " "
return _completion_matches[state]
except IndexError:
return None


@contextmanager
def completer():
def completer(con):
try:
import readline
except ImportError:
yield
return

old_completer = readline.get_completer()
def complete(text, state):
return _complete(con, text, state)
try:
readline.set_completer(_complete)
readline.set_completer(complete)
if readline.backend == "editline":
# libedit uses "^I" instead of "tab"
command_string = "bind ^I rl_complete"
Expand Down
128 changes: 121 additions & 7 deletions Lib/test/test_sqlite3/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,6 @@ class Completion(unittest.TestCase):

@classmethod
def setUpClass(cls):
_sqlite3 = import_module("_sqlite3")
if not hasattr(_sqlite3, "SQLITE_KEYWORDS"):
raise unittest.SkipTest("unable to determine SQLite keywords")

readline = import_module("readline")
if readline.backend == "editline":
raise unittest.SkipTest("libedit readline is not supported")
Expand All @@ -229,12 +225,24 @@ def write_input(self, input_, env=None):
import readline
from sqlite3.__main__ import main

# Configure readline to ...:
# - hide control sequences surrounding each candidate
# - hide "Display all xxx possibilities? (y or n)"
# - show candidates one per line
readline.parse_and_bind("set colored-completion-prefix off")
readline.parse_and_bind("set completion-query-items 0")
readline.parse_and_bind("set page-completions off")
readline.parse_and_bind("set completion-display-width 0")

main()
""")
return run_pty(script, input_, env)

def test_complete_sql_keywords(self):
_sqlite3 = import_module("_sqlite3")
if not hasattr(_sqlite3, "SQLITE_KEYWORDS"):
raise unittest.SkipTest("unable to determine SQLite keywords")

# List candidates starting with 'S', there should be multiple matches.
input_ = b"S\t\tEL\t 1;\n.quit\n"
output = self.write_input(input_)
Expand All @@ -254,6 +262,114 @@ def test_complete_sql_keywords(self):
output = self.write_input(input_)
self.assertIn(b".version", output)

def test_complete_table_indexes_triggers_views(self):
input_ = textwrap.dedent("""\
CREATE TABLE _Table (id);
CREATE INDEX _Index ON _table (id);
CREATE TRIGGER _Trigger BEFORE INSERT
ON _Table BEGIN SELECT 1; END;
CREATE VIEW _View AS SELECT 1;

CREATE TEMP TABLE _Temp_table (id);
CREATE INDEX temp._Temp_index ON _Temp_table (id);
CREATE TEMP TRIGGER _Temp_trigger BEFORE INSERT
ON _Table BEGIN SELECT 1; END;
CREATE TEMP VIEW _Temp_view AS SELECT 1;

ATTACH ':memory:' AS attached;
CREATE TABLE attached._Attached_table (id);
CREATE INDEX attached._Attached_index ON _Attached_table (id);
CREATE TRIGGER attached._Attached_trigger BEFORE INSERT
ON _Attached_table BEGIN SELECT 1; END;
CREATE VIEW attached._Attached_view AS SELECT 1;

SELECT id FROM _\t\tta\t;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The last \t (after _a) isn't tested. Did you mean to add another assertion with start, end = indices[-2], indices[-1]?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your review! I will add it.

.quit\n""").encode()
output = self.write_input(input_)
lines = output.decode().splitlines()
indices = [i for i, line in enumerate(lines)
if line.startswith(self.PS1)]
start, end = indices[-3], indices[-2]
candidates = [l.strip() for l in lines[start+1:end]]
self.assertEqual(candidates,
[
"_Attached_index",
"_Attached_table",
"_Attached_trigger",
"_Attached_view",
"_Index",
"_Table",
"_Temp_index",
"_Temp_table",
"_Temp_trigger",
"_Temp_view",
"_Trigger",
"_View",
],
)

@unittest.skipIf(sqlite3.sqlite_version_info < (3, 16, 0),
"PRAGMA table-valued function is not available until "
"SQLite 3.16.0")
def test_complete_columns(self):
input_ = textwrap.dedent("""\
CREATE TABLE _table (_col_table);
CREATE TEMP TABLE _temp_table (_col_temp);
ATTACH ':memory:' AS attached;
CREATE TABLE attached._attached_table (_col_attached);

SELECT _col_\t\tta\tFROM _table;
.quit\n""").encode()
output = self.write_input(input_)
lines = output.decode().splitlines()
indices = [
i for i, line in enumerate(lines) if line.startswith(self.PS1)
]
start, end = indices[-3], indices[-2]
candidates = [l.strip() for l in lines[start+1:end]]

self.assertEqual(
candidates, ["_col_attached", "_col_table", "_col_temp"]
)

@unittest.skipIf(sqlite3.sqlite_version_info < (3, 30, 0),
"PRAGMA function_list is not available until "
"SQLite 3.30.0")
def test_complete_functions(self):
input_ = b"SELECT AV\t1);\n.quit\n"
output = self.write_input(input_)
self.assertIn(b"AVG(1);", output)
self.assertIn(b"(1.0,)", output)

# Functions are completed in upper case for even lower case user input.
input_ = b"SELECT av\t1);\n.quit\n"
output = self.write_input(input_)
self.assertIn(b"AVG(1);", output)
self.assertIn(b"(1.0,)", output)

def test_complete_schemata(self):
input_ = textwrap.dedent("""\
ATTACH ':memory:' AS MixedCase;
-- Test '_' is escaped in Like pattern filtering
ATTACH ':memory:' AS _underscore;
-- Let database_list pragma have a 'temp' schema entry
CREATE TEMP TABLE _table (id);

SELECT * FROM \t\tmIX\t.sqlite_master;
SELECT * FROM _und\t.sqlite_master;
.quit\n""").encode()
output = self.write_input(input_)
lines = output.decode().splitlines()
indices = [
i for i, line in enumerate(lines) if line.startswith(self.PS1)
]
start, end = indices[-4], indices[-3]
candidates = [l.strip() for l in lines[start+1:end]]
self.assertIn("MixedCase", candidates)
self.assertIn("_underscore", candidates)
self.assertIn("main", candidates)
self.assertIn("temp", candidates)

@unittest.skipIf(sys.platform.startswith("freebsd"),
"Two actual tabs are inserted when there are no matching"
" completions in the pseudo-terminal opened by run_pty()"
Expand All @@ -274,8 +390,6 @@ def test_complete_no_match(self):
self.assertEqual(line_num, len(lines))

def test_complete_no_input(self):
from _sqlite3 import SQLITE_KEYWORDS

script = textwrap.dedent("""
import readline
from sqlite3.__main__ import main
Expand Down Expand Up @@ -306,7 +420,7 @@ def test_complete_no_input(self):
self.assertEqual(len(indices), 2)
start, end = indices
candidates = [l.strip() for l in lines[start+1:end]]
self.assertEqual(candidates, sorted(SQLITE_KEYWORDS))
self.assertEqual(candidates, sorted(candidates))
except:
if verbose:
print(' PTY output: '.center(30, '-'))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Support table, index, trigger, view, column, function, and schema completion
for :mod:`sqlite3`'s :ref:`command-line interface <sqlite3-cli>`.
Loading