Skip to content

Commit c73eb4a

Browse files
committed
'u' to Use database as default - and star currently used db in explorer
1 parent 3d52f2a commit c73eb4a

File tree

4 files changed

+113
-7
lines changed

4 files changed

+113
-7
lines changed

sqlit/app.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ class SSMSTUI(
292292
Binding("z", "collapse_tree", "Collapse", show=False),
293293
Binding("j", "tree_cursor_down", "Down", show=False),
294294
Binding("k", "tree_cursor_up", "Up", show=False),
295+
Binding("u", "use_database", "Use as default", show=False),
295296
Binding("v", "view_cell", "View cell", show=False),
296297
Binding("u", "edit_cell", "Update cell", show=False),
297298
Binding("h", "results_cursor_left", "Left", show=False),

sqlit/state_machine.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -601,8 +601,44 @@ def is_active(self, app: SSMSTUI) -> bool:
601601
return node is not None and _get_node_kind(node) in ("table", "view")
602602

603603

604+
class TreeOnDatabaseState(State):
605+
"""Tree focused on a database node (in multi-database servers)."""
606+
607+
help_category = "Explorer"
608+
609+
def _setup_actions(self) -> None:
610+
self.allows("use_database", key="u", label="Use as default", help="Set as default database")
611+
612+
def get_display_bindings(self, app: SSMSTUI) -> tuple[list[DisplayBinding], list[DisplayBinding]]:
613+
left: list[DisplayBinding] = []
614+
seen: set[str] = set()
615+
616+
left.append(DisplayBinding(key="enter", label="Expand", action="toggle_node"))
617+
seen.add("toggle_node")
618+
left.append(DisplayBinding(key="u", label="Use as default", action="use_database"))
619+
seen.add("use_database")
620+
left.append(DisplayBinding(key="f", label="Refresh", action="refresh_tree"))
621+
seen.add("refresh_tree")
622+
623+
right: list[DisplayBinding] = []
624+
if self.parent:
625+
_, parent_right = self.parent.get_display_bindings(app)
626+
for binding in parent_right:
627+
if binding.action not in seen:
628+
right.append(binding)
629+
seen.add(binding.action)
630+
631+
return left, right
632+
633+
def is_active(self, app: SSMSTUI) -> bool:
634+
if not app.object_tree.has_focus:
635+
return False
636+
node = app.object_tree.cursor_node
637+
return node is not None and _get_node_kind(node) == "database"
638+
639+
604640
class TreeOnFolderState(State):
605-
"""Tree focused on a folder, database, or schema node."""
641+
"""Tree focused on a folder or schema node."""
606642

607643
def _setup_actions(self) -> None:
608644
pass # Just inherits from parent
@@ -630,7 +666,7 @@ def is_active(self, app: SSMSTUI) -> bool:
630666
if not app.object_tree.has_focus:
631667
return False
632668
node = app.object_tree.cursor_node
633-
return node is not None and _get_node_kind(node) in ("folder", "database", "schema")
669+
return node is not None and _get_node_kind(node) in ("folder", "schema")
634670

635671

636672
class TreeOnObjectState(State):
@@ -906,6 +942,7 @@ def __init__(self) -> None:
906942
self.tree_focused = TreeFocusedState(parent=self.main_screen)
907943
self.tree_filter_active = TreeFilterActiveState(parent=self.main_screen)
908944
self.tree_on_connection = TreeOnConnectionState(parent=self.tree_focused)
945+
self.tree_on_database = TreeOnDatabaseState(parent=self.tree_focused)
909946
self.tree_on_table = TreeOnTableState(parent=self.tree_focused)
910947
self.tree_on_folder = TreeOnFolderState(parent=self.tree_focused)
911948
self.tree_on_object = TreeOnObjectState(parent=self.tree_focused)
@@ -924,6 +961,7 @@ def __init__(self) -> None:
924961
self.leader_pending,
925962
self.tree_filter_active, # Before tree_focused (more specific when filter active)
926963
self.tree_on_connection,
964+
self.tree_on_database, # For database nodes (multi-database servers)
927965
self.tree_on_table,
928966
self.tree_on_folder,
929967
self.tree_on_object, # For index/trigger/sequence nodes

sqlit/ui/mixins/query.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ async def _run_query_async(self: AppProtocol, query: str, keep_insert_mode: bool
125125
"""Run query asynchronously using a cancellable dedicated connection."""
126126
import asyncio
127127
import time
128-
from dataclasses import replace
129128

130129
from ...services import CancellableQuery, QueryResult, QueryService
131130
from ...services.query import parse_use_statement
@@ -141,11 +140,9 @@ async def _run_query_async(self: AppProtocol, query: str, keep_insert_mode: bool
141140
# Handle USE database statements
142141
db_name = parse_use_statement(query)
143142
if db_name is not None:
144-
self.current_config = replace(config, database=db_name)
145143
self._stop_query_spinner()
146144
self._display_non_query_result(0, 0)
147-
self.notify(f"Switched to database: {db_name}")
148-
self._update_status_bar()
145+
self.set_default_database(db_name) # type: ignore[attr-defined]
149146
if keep_insert_mode:
150147
self._restore_insert_mode()
151148
return

sqlit/ui/mixins/tree.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,14 @@ def get_conn_label(config: Any, connected: Any = False) -> str:
173173
dbs_node.data = FolderNode(folder_type="databases")
174174

175175
databases = self._run_db_call(adapter.get_databases, self.current_connection)
176+
default_db = self.current_config.database if self.current_config else None
176177
for db_name in databases:
177-
db_node = dbs_node.add(escape_markup(db_name))
178+
# Show default database with star and green text
179+
if default_db and db_name.lower() == default_db.lower():
180+
db_label = f"[#4ADE80]* {escape_markup(db_name)}[/]"
181+
else:
182+
db_label = escape_markup(db_name)
183+
db_node = dbs_node.add(db_label)
178184
db_node.data = DatabaseNode(name=db_name)
179185
db_node.allow_expand = True
180186
self._add_database_object_nodes(db_node, db_name)
@@ -700,3 +706,67 @@ def _display_object_info(self: AppProtocol, object_type: str, info: dict) -> Non
700706
definition = info.get("definition")
701707
if definition:
702708
self.query_input.text = f"/*\n{definition}\n*/"
709+
710+
def set_default_database(self: AppProtocol, db_name: str) -> None:
711+
"""Set the default database for the current connection.
712+
713+
This is the shared function used by both the USE query handler and
714+
the explorer 'Use as default' action.
715+
716+
Args:
717+
db_name: The database name to set as default.
718+
"""
719+
from dataclasses import replace
720+
721+
if not self.current_config:
722+
self.notify("Not connected", severity="error")
723+
return
724+
725+
self.current_config = replace(self.current_config, database=db_name)
726+
self.notify(f"Switched to database: {db_name}")
727+
self._update_status_bar()
728+
self._update_database_labels()
729+
730+
def _update_database_labels(self: AppProtocol) -> None:
731+
"""Update database node labels to show the default database with a star."""
732+
if not self.current_config:
733+
return
734+
735+
default_db = self.current_config.database
736+
737+
# Find the Databases folder and update labels
738+
for conn_node in self.object_tree.root.children:
739+
if self._get_node_kind(conn_node) != "connection":
740+
continue
741+
742+
# Check if this is the active connection
743+
if not (conn_node.data and conn_node.data.config.name == self.current_config.name):
744+
continue
745+
746+
# Find Databases folder
747+
for child in conn_node.children:
748+
if self._get_node_kind(child) == "folder" and child.data.folder_type == "databases":
749+
# Update each database node
750+
for db_node in child.children:
751+
if self._get_node_kind(db_node) == "database":
752+
db_name = db_node.data.name
753+
if default_db and db_name.lower() == default_db.lower():
754+
db_node.set_label(f"[#4ADE80]* {escape_markup(db_name)}[/]")
755+
else:
756+
db_node.set_label(escape_markup(db_name))
757+
break
758+
break
759+
760+
def action_use_database(self: AppProtocol) -> None:
761+
"""Set the selected database as the default for the current connection."""
762+
node = self.object_tree.cursor_node
763+
764+
if not node or self._get_node_kind(node) != "database":
765+
return
766+
767+
if not self.current_connection:
768+
self.notify("Not connected", severity="error")
769+
return
770+
771+
db_name = node.data.name
772+
self.set_default_database(db_name)

0 commit comments

Comments
 (0)