|
| 1 | +import argparse |
| 2 | +import sys |
| 3 | +from code import interact |
| 4 | +from collections import ChainMap |
| 5 | +from datajoint import __version__ as version, config, create_virtual_module |
| 6 | + |
| 7 | + |
| 8 | +def dj_cli(args: list = None): |
| 9 | + """ |
| 10 | + Console interface for DataJoint Python |
| 11 | +
|
| 12 | + :param args: List of arguments to be passed in, defaults to reading stdin |
| 13 | + :type args: list, optional |
| 14 | + """ |
| 15 | + parser = argparse.ArgumentParser( |
| 16 | + prog="datajoint", |
| 17 | + description="DataJoint console interface.", |
| 18 | + conflict_handler="resolve", |
| 19 | + ) |
| 20 | + parser.add_argument( |
| 21 | + "-V", "--version", action="version", version=f"datajoint {version}" |
| 22 | + ) |
| 23 | + parser.add_argument( |
| 24 | + "-u", |
| 25 | + "--user", |
| 26 | + type=str, |
| 27 | + default=config["database.user"], |
| 28 | + required=False, |
| 29 | + help="Datajoint username", |
| 30 | + ) |
| 31 | + parser.add_argument( |
| 32 | + "-p", |
| 33 | + "--password", |
| 34 | + type=str, |
| 35 | + default=config["database.password"], |
| 36 | + required=False, |
| 37 | + help="Datajoint password", |
| 38 | + ) |
| 39 | + parser.add_argument( |
| 40 | + "-h", |
| 41 | + "--host", |
| 42 | + type=str, |
| 43 | + default=config["database.host"], |
| 44 | + required=False, |
| 45 | + help="Datajoint host", |
| 46 | + ) |
| 47 | + parser.add_argument( |
| 48 | + "-s", |
| 49 | + "--schemas", |
| 50 | + nargs="+", |
| 51 | + type=[str], |
| 52 | + default=[], |
| 53 | + required=False, |
| 54 | + help="A list of virtual module mappings in `db:schema ...` format", |
| 55 | + ) |
| 56 | + kwargs = vars(parser.parse_args(args if sys.argv[1:] else ["--help"])) |
| 57 | + mods = {} |
| 58 | + if kwargs["user"]: |
| 59 | + config["database.user"] = kwargs["user"] |
| 60 | + if kwargs["password"]: |
| 61 | + config["database.password"] = kwargs["password"] |
| 62 | + if kwargs["host"]: |
| 63 | + config["database.host"] = kwargs["host"] |
| 64 | + if kwargs["schemas"]: |
| 65 | + d, m = kwargs["schemas"].split(":") |
| 66 | + mods[m] = create_virtual_module(m, d) |
| 67 | + |
| 68 | + banner = "dj repl\n" |
| 69 | + if mods: |
| 70 | + modstr = "\n".join(" - {}".format(m) for m in mods) |
| 71 | + banner += "\nschema modules:\n\n" + modstr + "\n" |
| 72 | + interact(banner, local=dict(ChainMap(mods, locals(), globals()))) |
| 73 | + |
| 74 | + raise SystemExit |
| 75 | + |
| 76 | + |
| 77 | +if __name__ == "__main__": |
| 78 | + dj_cli() |
0 commit comments