diff --git a/.gitignore b/.gitignore index c021e7d41d..0ca16b47aa 100644 --- a/.gitignore +++ b/.gitignore @@ -73,6 +73,9 @@ target/ # pyenv .python-version +# asdf +.tool-versions + # celery beat schedule file celerybeat-schedule diff --git a/deeppavlov/core/commands/infer.py b/deeppavlov/core/commands/infer.py index 5a6ed65023..46f4094f82 100644 --- a/deeppavlov/core/commands/infer.py +++ b/deeppavlov/core/commands/infer.py @@ -16,7 +16,7 @@ from itertools import islice from logging import getLogger from pathlib import Path -from typing import Optional, Union +from typing import Any, Callable, Optional, Union from deeppavlov.core.commands.utils import import_packages, parse_config from deeppavlov.core.common.chainer import Chainer @@ -28,8 +28,13 @@ log = getLogger(__name__) -def build_model(config: Union[str, Path, dict], mode: str = 'infer', - load_trained: bool = False, install: bool = False, download: bool = False) -> Chainer: +def build_model( + config: Union[str, Path, dict], + mode: str = "infer", + load_trained: bool = False, + install: bool = False, + download: bool = False, +) -> Chainer: """Build and return the model described in corresponding configuration file.""" config = parse_config(config) @@ -38,66 +43,94 @@ def build_model(config: Union[str, Path, dict], mode: str = 'infer', if download: deep_download(config) - import_packages(config.get('metadata', {}).get('imports', [])) + import_packages(config.get("metadata", {}).get("imports", [])) - model_config = config['chainer'] + model_config = config["chainer"] - model = Chainer(model_config['in'], model_config['out'], model_config.get('in_y')) + model = Chainer(model_config["in"], model_config["out"], model_config.get("in_y")) - for component_config in model_config['pipe']: - if load_trained and ('fit_on' in component_config or 'in_y' in component_config): + for component_config in model_config["pipe"]: + if load_trained and ("fit_on" in component_config or "in_y" in component_config): try: - component_config['load_path'] = component_config['save_path'] + component_config["load_path"] = component_config["save_path"] except KeyError: - log.warning('No "save_path" parameter for the {} component, so "load_path" will not be renewed' - .format(component_config.get('class_name', component_config.get('ref', 'UNKNOWN')))) + log.warning( + 'No "save_path" parameter for the {} component, so "load_path" will not be renewed'.format( + component_config.get("class_name", component_config.get("ref", "UNKNOWN")) + ) + ) component = from_params(component_config, mode=mode) - if 'id' in component_config: - model._components_dict[component_config['id']] = component + if "id" in component_config: + model._components_dict[component_config["id"]] = component - if 'in' in component_config: - c_in = component_config['in'] - c_out = component_config['out'] - in_y = component_config.get('in_y', None) - main = component_config.get('main', False) + if "in" in component_config: + c_in = component_config["in"] + c_out = component_config["out"] + in_y = component_config.get("in_y", None) + main = component_config.get("main", False) model.append(component, c_in, c_out, in_y, main) return model +def end_repl_mode(function: Callable[..., Any]) -> Callable[..., Any]: + """Decorator for processing Ctrl-C, Ctrl-D pressing.""" + + def wrapper(*args: Any, **kwargs: Any): + try: + return function(*args, **kwargs) + except (KeyboardInterrupt, EOFError): + print("\nExit.") + sys.exit(0) + + return wrapper + + +def preparing_arguments(model): + """Prepares arguments.""" + arguments = [] + for in_x in model: + data: str = input(f"\033[34m\033[107m{in_x}:\033[0m ") + if data.strip() == "q": + print("\nExit.") + sys.exit(0) + arguments.append((data,)) + return arguments + + +@end_repl_mode def interact_model(config: Union[str, Path, dict]) -> None: """Start interaction with the model described in corresponding configuration file.""" model = build_model(config) + print("\nExit - type q and press Enter, or press Ctrl-C, or Ctrl-D.") + while True: - args = [] - for in_x in model.in_x: - args.append((input('{}::'.format(in_x)),)) - # check for exit command - if args[-1][0] in {'exit', 'stop', 'quit', 'q'}: - return + arguments = preparing_arguments(model.in_x) - pred = model(*args) + pred = model(*arguments) if len(model.out_params) > 1: pred = zip(*pred) - print('>>', *pred) + print(">> ", *pred) -def predict_on_stream(config: Union[str, Path, dict], - batch_size: Optional[int] = None, - file_path: Optional[str] = None) -> None: +def predict_on_stream( + config: Union[str, Path, dict], + batch_size: Optional[int] = None, + file_path: Optional[str] = None, +) -> None: """Make a prediction with the component described in corresponding configuration file.""" batch_size = batch_size or 1 - if file_path is None or file_path == '-': + if file_path is None or file_path == "-": if sys.stdin.isatty(): - raise RuntimeError('To process data from terminal please use interact mode') + raise RuntimeError("To process data from terminal please use interact mode") f = sys.stdin else: - f = open(file_path, encoding='utf8') + f = open(file_path, encoding="utf8") model: Chainer = build_model(config) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..6b90104c10 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[tool.black] +line-length = 120 +target-version = ["py36", "py37", "py38", "py39", "py310"] diff --git a/tests/test_core_commands/test_infer.py b/tests/test_core_commands/test_infer.py new file mode 100644 index 0000000000..fe1d9cd507 --- /dev/null +++ b/tests/test_core_commands/test_infer.py @@ -0,0 +1,48 @@ +import pytest +from deeppavlov.core.commands.infer import end_repl_mode, interact_model, preparing_arguments + + +def test_end_repl_mode_decorator_keyboard_interrupt(): + """Check Ctrl-C.""" + + def error_keyboard_interrupt(): + raise KeyboardInterrupt + + with pytest.raises(SystemExit) as ex: + function = end_repl_mode(error_keyboard_interrupt) + function() + assert ex.value.code == 0 + + +def test_end_repl_mode_decorator_eoferror(): + """Check Ctrl-D.""" + + def error_eoferror(): + raise EOFError + + with pytest.raises(SystemExit) as ex: + function = end_repl_mode(error_eoferror) + function() + assert ex.value.code == 0 + + +def test_preparing_arguments(monkeypatch): + """Check format arguments.""" + + def input_data(data: str): + return "data" + + monkeypatch.setattr("builtins.input", input_data) + assert preparing_arguments([1, 2]) == [("data",), ("data",)] + + +def test_preparing_arguments_exit(monkeypatch): + """Check exit by `q`""" + + def input_data(data: str): + return "q" + + monkeypatch.setattr("builtins.input", input_data) + with pytest.raises(SystemExit) as ex: + preparing_arguments([1, 2]) + assert ex.value.code == 0