-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Cli update #1666
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Cli update #1666
Changes from all commits
f764627
055601f
f3f1872
995cfcc
89decfd
4edf52f
cba4e3f
4d49cff
cc72cec
d7f91ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,7 @@ | |
from itertools import islice | ||
from logging import getLogger | ||
from pathlib import Path | ||
from typing import Optional, Union | ||
from typing import Any, Callable, Optional, Union | ||
|
||
from deeppavlov.core.commands.utils import import_packages, parse_config | ||
from deeppavlov.core.common.chainer import Chainer | ||
|
@@ -28,8 +28,13 @@ | |
log = getLogger(__name__) | ||
|
||
|
||
def build_model(config: Union[str, Path, dict], mode: str = 'infer', | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Revert unnecessary style changes here and below. If using black (I guess), we should use it on whole code base. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. UPD: make black line lenght 120, and commit black config |
||
load_trained: bool = False, install: bool = False, download: bool = False) -> Chainer: | ||
def build_model( | ||
config: Union[str, Path, dict], | ||
mode: str = "infer", | ||
load_trained: bool = False, | ||
install: bool = False, | ||
download: bool = False, | ||
) -> Chainer: | ||
"""Build and return the model described in corresponding configuration file.""" | ||
config = parse_config(config) | ||
|
||
|
@@ -38,66 +43,94 @@ def build_model(config: Union[str, Path, dict], mode: str = 'infer', | |
if download: | ||
deep_download(config) | ||
|
||
import_packages(config.get('metadata', {}).get('imports', [])) | ||
import_packages(config.get("metadata", {}).get("imports", [])) | ||
|
||
model_config = config['chainer'] | ||
model_config = config["chainer"] | ||
|
||
model = Chainer(model_config['in'], model_config['out'], model_config.get('in_y')) | ||
model = Chainer(model_config["in"], model_config["out"], model_config.get("in_y")) | ||
|
||
for component_config in model_config['pipe']: | ||
if load_trained and ('fit_on' in component_config or 'in_y' in component_config): | ||
for component_config in model_config["pipe"]: | ||
if load_trained and ("fit_on" in component_config or "in_y" in component_config): | ||
try: | ||
component_config['load_path'] = component_config['save_path'] | ||
component_config["load_path"] = component_config["save_path"] | ||
except KeyError: | ||
log.warning('No "save_path" parameter for the {} component, so "load_path" will not be renewed' | ||
.format(component_config.get('class_name', component_config.get('ref', 'UNKNOWN')))) | ||
log.warning( | ||
'No "save_path" parameter for the {} component, so "load_path" will not be renewed'.format( | ||
component_config.get("class_name", component_config.get("ref", "UNKNOWN")) | ||
) | ||
) | ||
|
||
component = from_params(component_config, mode=mode) | ||
|
||
if 'id' in component_config: | ||
model._components_dict[component_config['id']] = component | ||
if "id" in component_config: | ||
model._components_dict[component_config["id"]] = component | ||
|
||
if 'in' in component_config: | ||
c_in = component_config['in'] | ||
c_out = component_config['out'] | ||
in_y = component_config.get('in_y', None) | ||
main = component_config.get('main', False) | ||
if "in" in component_config: | ||
c_in = component_config["in"] | ||
c_out = component_config["out"] | ||
in_y = component_config.get("in_y", None) | ||
main = component_config.get("main", False) | ||
model.append(component, c_in, c_out, in_y, main) | ||
|
||
return model | ||
|
||
|
||
def end_repl_mode(function: Callable[..., Any]) -> Callable[..., Any]: | ||
"""Decorator for processing Ctrl-C, Ctrl-D pressing.""" | ||
|
||
def wrapper(*args: Any, **kwargs: Any): | ||
try: | ||
return function(*args, **kwargs) | ||
except (KeyboardInterrupt, EOFError): | ||
print("\nExit.") | ||
sys.exit(0) | ||
|
||
return wrapper | ||
|
||
|
||
def preparing_arguments(model): | ||
"""Prepares arguments.""" | ||
arguments = [] | ||
for in_x in model: | ||
data: str = input(f"\033[34m\033[107m{in_x}:\033[0m ") | ||
if data.strip() == "q": | ||
print("\nExit.") | ||
sys.exit(0) | ||
arguments.append((data,)) | ||
return arguments | ||
|
||
|
||
@end_repl_mode | ||
def interact_model(config: Union[str, Path, dict]) -> None: | ||
"""Start interaction with the model described in corresponding configuration file.""" | ||
model = build_model(config) | ||
|
||
print("\nExit - type q and press Enter, or press Ctrl-C, or Ctrl-D.") | ||
|
||
while True: | ||
args = [] | ||
for in_x in model.in_x: | ||
args.append((input('{}::'.format(in_x)),)) | ||
# check for exit command | ||
if args[-1][0] in {'exit', 'stop', 'quit', 'q'}: | ||
return | ||
arguments = preparing_arguments(model.in_x) | ||
|
||
pred = model(*args) | ||
pred = model(*arguments) | ||
if len(model.out_params) > 1: | ||
pred = zip(*pred) | ||
|
||
print('>>', *pred) | ||
print(">> ", *pred) | ||
|
||
|
||
def predict_on_stream(config: Union[str, Path, dict], | ||
batch_size: Optional[int] = None, | ||
file_path: Optional[str] = None) -> None: | ||
def predict_on_stream( | ||
config: Union[str, Path, dict], | ||
batch_size: Optional[int] = None, | ||
file_path: Optional[str] = None, | ||
) -> None: | ||
"""Make a prediction with the component described in corresponding configuration file.""" | ||
|
||
batch_size = batch_size or 1 | ||
if file_path is None or file_path == '-': | ||
if file_path is None or file_path == "-": | ||
if sys.stdin.isatty(): | ||
raise RuntimeError('To process data from terminal please use interact mode') | ||
raise RuntimeError("To process data from terminal please use interact mode") | ||
f = sys.stdin | ||
else: | ||
f = open(file_path, encoding='utf8') | ||
f = open(file_path, encoding="utf8") | ||
|
||
model: Chainer = build_model(config) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[tool.black] | ||
line-length = 120 | ||
target-version = ["py36", "py37", "py38", "py39", "py310"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import pytest | ||
from deeppavlov.core.commands.infer import end_repl_mode, interact_model, preparing_arguments | ||
|
||
|
||
def test_end_repl_mode_decorator_keyboard_interrupt(): | ||
"""Check Ctrl-C.""" | ||
|
||
def error_keyboard_interrupt(): | ||
raise KeyboardInterrupt | ||
|
||
with pytest.raises(SystemExit) as ex: | ||
function = end_repl_mode(error_keyboard_interrupt) | ||
function() | ||
assert ex.value.code == 0 | ||
|
||
|
||
def test_end_repl_mode_decorator_eoferror(): | ||
"""Check Ctrl-D.""" | ||
|
||
def error_eoferror(): | ||
raise EOFError | ||
|
||
with pytest.raises(SystemExit) as ex: | ||
function = end_repl_mode(error_eoferror) | ||
function() | ||
assert ex.value.code == 0 | ||
|
||
|
||
def test_preparing_arguments(monkeypatch): | ||
"""Check format arguments.""" | ||
|
||
def input_data(data: str): | ||
return "data" | ||
|
||
monkeypatch.setattr("builtins.input", input_data) | ||
assert preparing_arguments([1, 2]) == [("data",), ("data",)] | ||
|
||
|
||
def test_preparing_arguments_exit(monkeypatch): | ||
"""Check exit by `q`""" | ||
|
||
def input_data(data: str): | ||
return "q" | ||
|
||
monkeypatch.setattr("builtins.input", input_data) | ||
with pytest.raises(SystemExit) as ex: | ||
preparing_arguments([1, 2]) | ||
assert ex.value.code == 0 |
Uh oh!
There was an error while loading. Please reload this page.