Skip to content

Commit 73f9fba

Browse files
fix: resolve CLI -h conflict with --help
- Remove -h shorthand for --host (conflicts with argparse --help) - Add module-level docstring with usage examples - Improve function docstring with NumPy style - Add explicit error handling for invalid schema format - Improve banner message with version and usage hint - Use modern type hints (list[str] | None) - Fix locals() issue: explicitly include dj in REPL namespace Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 2665f5b commit 73f9fba

File tree

1 file changed

+78
-33
lines changed

1 file changed

+78
-33
lines changed

src/datajoint/cli.py

Lines changed: 78 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,119 @@
1+
"""
2+
DataJoint command-line interface.
3+
4+
Provides a Python REPL with DataJoint pre-loaded and optional schema access.
5+
6+
Usage::
7+
8+
# Start REPL with database credentials
9+
dj --user root --password secret --host localhost:3306
10+
11+
# Load schemas as virtual modules
12+
dj -s my_lab:lab -s my_analysis:analysis
13+
14+
# In the REPL
15+
>>> lab.Subject.to_dicts()
16+
>>> dj.Diagram(lab.schema)
17+
"""
18+
19+
from __future__ import annotations
20+
121
import argparse
222
from code import interact
323
from collections import ChainMap
424

525
import datajoint as dj
626

727

8-
def cli(args: list = None):
28+
def cli(args: list[str] | None = None) -> None:
929
"""
10-
Console interface for DataJoint Python
30+
DataJoint command-line interface.
31+
32+
Starts an interactive Python REPL with DataJoint imported and configured.
33+
Optionally loads database schemas as virtual modules for quick exploration.
34+
35+
Parameters
36+
----------
37+
args : list[str], optional
38+
Command-line arguments. If None, reads from sys.argv.
1139
12-
:param args: List of arguments to be passed in, defaults to reading stdin
13-
:type args: list, optional
40+
Examples
41+
--------
42+
From the command line::
43+
44+
$ dj --host localhost:3306 --user root --password secret
45+
$ dj -s my_lab:lab -s my_analysis:analysis
46+
47+
Programmatically::
48+
49+
>>> from datajoint.cli import cli
50+
>>> cli(["--version"])
1451
"""
1552
parser = argparse.ArgumentParser(
16-
prog="datajoint",
17-
description="DataJoint console interface.",
18-
conflict_handler="resolve",
53+
prog="dj",
54+
description="DataJoint interactive console. Start a Python REPL with DataJoint pre-loaded.",
55+
epilog="Example: dj -s my_lab:lab --host localhost:3306",
56+
)
57+
parser.add_argument(
58+
"-V", "--version",
59+
action="version",
60+
version=f"{dj.__name__} {dj.__version__}",
1961
)
20-
parser.add_argument("-V", "--version", action="version", version=f"{dj.__name__} {dj.__version__}")
2162
parser.add_argument(
22-
"-u",
23-
"--user",
63+
"-u", "--user",
2464
type=str,
25-
default=dj.config["database.user"],
26-
required=False,
27-
help="Datajoint username",
65+
default=None,
66+
help="Database username (default: from config)",
2867
)
2968
parser.add_argument(
30-
"-p",
31-
"--password",
69+
"-p", "--password",
3270
type=str,
33-
default=dj.config["database.password"],
34-
required=False,
35-
help="Datajoint password",
71+
default=None,
72+
help="Database password (default: from config)",
3673
)
3774
parser.add_argument(
38-
"-h",
3975
"--host",
4076
type=str,
41-
default=dj.config["database.host"],
42-
required=False,
43-
help="Datajoint host",
77+
default=None,
78+
help="Database host as host:port (default: from config)",
4479
)
4580
parser.add_argument(
46-
"-s",
47-
"--schemas",
81+
"-s", "--schemas",
4882
nargs="+",
4983
type=str,
50-
required=False,
51-
help="A list of virtual module mappings in `db:schema ...` format",
84+
metavar="DB:ALIAS",
85+
help="Load schemas as virtual modules. Format: schema_name:alias",
5286
)
87+
5388
kwargs = vars(parser.parse_args(args))
54-
mods = {}
89+
90+
# Apply credentials to config
5591
if kwargs["user"]:
5692
dj.config["database.user"] = kwargs["user"]
5793
if kwargs["password"]:
5894
dj.config["database.password"] = kwargs["password"]
5995
if kwargs["host"]:
6096
dj.config["database.host"] = kwargs["host"]
97+
98+
# Load requested schemas
99+
mods: dict[str, dj.VirtualModule] = {}
61100
if kwargs["schemas"]:
62101
for vm in kwargs["schemas"]:
63-
d, m = vm.split(":")
64-
mods[m] = dj.VirtualModule(m, d)
102+
if ":" not in vm:
103+
parser.error(f"Invalid schema format '{vm}'. Use schema_name:alias")
104+
schema_name, alias = vm.split(":", 1)
105+
mods[alias] = dj.VirtualModule(alias, schema_name)
65106

66-
banner = "dj repl\n"
107+
# Build banner
108+
banner = f"DataJoint {dj.__version__} REPL\n"
109+
banner += "Type 'dj.' and press Tab for available functions.\n"
67110
if mods:
68-
modstr = "\n".join(" - {}".format(m) for m in mods)
69-
banner += "\nschema modules:\n\n" + modstr + "\n"
70-
interact(banner, local=dict(ChainMap(mods, locals(), globals())))
111+
banner += "\nLoaded schemas:\n"
112+
for alias in mods:
113+
banner += f" {alias} -> {mods[alias].schema.database}\n"
71114

115+
# Start interactive session
116+
interact(banner, local=dict(ChainMap(mods, {"dj": dj}, globals())))
72117
raise SystemExit
73118

74119

0 commit comments

Comments
 (0)