|
| 1 | +""" |
| 2 | +DataJoint command-line interface. |
| 3 | +
|
| 4 | +Provides a Python REPL with DataJoint pre-loaded and optional schema access. |
| 5 | +
|
| 6 | +Usage:: |
| 7 | +
|
| 8 | + # Start REPL with database credentials |
| 9 | + dj --user root --password secret --host localhost:3306 |
| 10 | +
|
| 11 | + # Load schemas as virtual modules |
| 12 | + dj -s my_lab:lab -s my_analysis:analysis |
| 13 | +
|
| 14 | + # In the REPL |
| 15 | + >>> lab.Subject.to_dicts() |
| 16 | + >>> dj.Diagram(lab.schema) |
| 17 | +""" |
| 18 | + |
| 19 | +from __future__ import annotations |
| 20 | + |
1 | 21 | import argparse |
2 | 22 | from code import interact |
3 | 23 | from collections import ChainMap |
4 | 24 |
|
5 | 25 | import datajoint as dj |
6 | 26 |
|
7 | 27 |
|
8 | | -def cli(args: list = None): |
| 28 | +def cli(args: list[str] | None = None) -> None: |
9 | 29 | """ |
10 | | - Console interface for DataJoint Python |
| 30 | + DataJoint command-line interface. |
| 31 | +
|
| 32 | + Starts an interactive Python REPL with DataJoint imported and configured. |
| 33 | + Optionally loads database schemas as virtual modules for quick exploration. |
| 34 | +
|
| 35 | + Parameters |
| 36 | + ---------- |
| 37 | + args : list[str], optional |
| 38 | + Command-line arguments. If None, reads from sys.argv. |
11 | 39 |
|
12 | | - :param args: List of arguments to be passed in, defaults to reading stdin |
13 | | - :type args: list, optional |
| 40 | + Examples |
| 41 | + -------- |
| 42 | + From the command line:: |
| 43 | +
|
| 44 | + $ dj --host localhost:3306 --user root --password secret |
| 45 | + $ dj -s my_lab:lab -s my_analysis:analysis |
| 46 | +
|
| 47 | + Programmatically:: |
| 48 | +
|
| 49 | + >>> from datajoint.cli import cli |
| 50 | + >>> cli(["--version"]) |
14 | 51 | """ |
15 | 52 | parser = argparse.ArgumentParser( |
16 | | - prog="datajoint", |
17 | | - description="DataJoint console interface.", |
18 | | - conflict_handler="resolve", |
| 53 | + prog="dj", |
| 54 | + description="DataJoint interactive console. Start a Python REPL with DataJoint pre-loaded.", |
| 55 | + epilog="Example: dj -s my_lab:lab --host localhost:3306", |
| 56 | + ) |
| 57 | + parser.add_argument( |
| 58 | + "-V", "--version", |
| 59 | + action="version", |
| 60 | + version=f"{dj.__name__} {dj.__version__}", |
19 | 61 | ) |
20 | | - parser.add_argument("-V", "--version", action="version", version=f"{dj.__name__} {dj.__version__}") |
21 | 62 | parser.add_argument( |
22 | | - "-u", |
23 | | - "--user", |
| 63 | + "-u", "--user", |
24 | 64 | type=str, |
25 | | - default=dj.config["database.user"], |
26 | | - required=False, |
27 | | - help="Datajoint username", |
| 65 | + default=None, |
| 66 | + help="Database username (default: from config)", |
28 | 67 | ) |
29 | 68 | parser.add_argument( |
30 | | - "-p", |
31 | | - "--password", |
| 69 | + "-p", "--password", |
32 | 70 | type=str, |
33 | | - default=dj.config["database.password"], |
34 | | - required=False, |
35 | | - help="Datajoint password", |
| 71 | + default=None, |
| 72 | + help="Database password (default: from config)", |
36 | 73 | ) |
37 | 74 | parser.add_argument( |
38 | | - "-h", |
39 | 75 | "--host", |
40 | 76 | type=str, |
41 | | - default=dj.config["database.host"], |
42 | | - required=False, |
43 | | - help="Datajoint host", |
| 77 | + default=None, |
| 78 | + help="Database host as host:port (default: from config)", |
44 | 79 | ) |
45 | 80 | parser.add_argument( |
46 | | - "-s", |
47 | | - "--schemas", |
| 81 | + "-s", "--schemas", |
48 | 82 | nargs="+", |
49 | 83 | type=str, |
50 | | - required=False, |
51 | | - help="A list of virtual module mappings in `db:schema ...` format", |
| 84 | + metavar="DB:ALIAS", |
| 85 | + help="Load schemas as virtual modules. Format: schema_name:alias", |
52 | 86 | ) |
| 87 | + |
53 | 88 | kwargs = vars(parser.parse_args(args)) |
54 | | - mods = {} |
| 89 | + |
| 90 | + # Apply credentials to config |
55 | 91 | if kwargs["user"]: |
56 | 92 | dj.config["database.user"] = kwargs["user"] |
57 | 93 | if kwargs["password"]: |
58 | 94 | dj.config["database.password"] = kwargs["password"] |
59 | 95 | if kwargs["host"]: |
60 | 96 | dj.config["database.host"] = kwargs["host"] |
| 97 | + |
| 98 | + # Load requested schemas |
| 99 | + mods: dict[str, dj.VirtualModule] = {} |
61 | 100 | if kwargs["schemas"]: |
62 | 101 | for vm in kwargs["schemas"]: |
63 | | - d, m = vm.split(":") |
64 | | - mods[m] = dj.VirtualModule(m, d) |
| 102 | + if ":" not in vm: |
| 103 | + parser.error(f"Invalid schema format '{vm}'. Use schema_name:alias") |
| 104 | + schema_name, alias = vm.split(":", 1) |
| 105 | + mods[alias] = dj.VirtualModule(alias, schema_name) |
65 | 106 |
|
66 | | - banner = "dj repl\n" |
| 107 | + # Build banner |
| 108 | + banner = f"DataJoint {dj.__version__} REPL\n" |
| 109 | + banner += "Type 'dj.' and press Tab for available functions.\n" |
67 | 110 | if mods: |
68 | | - modstr = "\n".join(" - {}".format(m) for m in mods) |
69 | | - banner += "\nschema modules:\n\n" + modstr + "\n" |
70 | | - interact(banner, local=dict(ChainMap(mods, locals(), globals()))) |
| 111 | + banner += "\nLoaded schemas:\n" |
| 112 | + for alias in mods: |
| 113 | + banner += f" {alias} -> {mods[alias].schema.database}\n" |
71 | 114 |
|
| 115 | + # Start interactive session |
| 116 | + interact(banner, local=dict(ChainMap(mods, {"dj": dj}, globals()))) |
72 | 117 | raise SystemExit |
73 | 118 |
|
74 | 119 |
|
|
0 commit comments