Skip to content

Commit 4ff1ee4

Browse files
committed
Upgrade to version 1.1.0 for Coq v8.10.0.
Major changes in this version: - Add a new command line interface. - Add a new VSCode interface with an LSP server. - In both of these interfaces, most arguments can be supplied via global or project-local .roosterizerc files. - Drop the redundant OpenNMTInterfaceForNaming model, and simplify model selection now that only the multi-source model remains. - Add a script for packaging binary distributions using PyInstaller.
1 parent bde2977 commit 4ff1ee4

25 files changed

+1118
-1492
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ MANIFEST
112112
# Usually these files are written by a python script from a template
113113
# before PyInstaller builds the exe, so as to inject date/other infos into it.
114114
*.manifest
115-
*.spec
115+
#*.spec
116116

117117
# Installer logs
118118
pip-log.txt

Makefile

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
INSTALL_DIR = $(HOME)/opt
2+
3+
PY_SRCS = preprocess.py train.py translate.py $(wildcard roosterize/*.py) $(wildcard onmt/*.py)
4+
5+
6+
all: package
7+
8+
.PHONY: package
9+
package: dist/roosterize/roosterize dist/roosterize.tgz
10+
11+
dist/roosterize/roosterize: roosterize.spec $(PY_SRCS)
12+
pyinstaller roosterize.spec -y --log-level WARN
13+
14+
dist/roosterize.tgz: dist/roosterize/roosterize
15+
cd dist && tar czf roosterize.tgz roosterize/
16+
17+
.PHONY: install
18+
install: package
19+
rm -rf $(INSTALL_DIR)/roosterize
20+
mkdir -p $(INSTALL_DIR)/roosterize
21+
cp -r dist/roosterize $(INSTALL_DIR)/roosterize/bin
22+
23+
.PHONY: clean
24+
clean:
25+
-rm -rf dist/ build/

onmt/utils/logging.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,6 @@
1-
# -*- coding: utf-8 -*-
2-
from __future__ import absolute_import
1+
from seutil import LoggingUtils
32

4-
import logging
3+
logger = LoggingUtils.get_logger(__name__)
54

6-
logger = logging.getLogger()
7-
8-
9-
def init_logger(log_file=None, log_file_level=logging.NOTSET):
10-
log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
11-
logger = logging.getLogger()
12-
logger.setLevel(logging.INFO)
13-
14-
console_handler = logging.StreamHandler()
15-
console_handler.setFormatter(log_format)
16-
logger.handlers = [console_handler]
17-
18-
if log_file and log_file != '':
19-
file_handler = logging.FileHandler(log_file)
20-
file_handler.setLevel(log_file_level)
21-
file_handler.setFormatter(log_format)
22-
logger.addHandler(file_handler)
23-
24-
return logger
5+
def init_logger(log_file=None, log_file_level=None):
6+
return logger

requirements.txt

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
configargparse
2-
nltk
1+
configargparse~=1.2.3
32
future
4-
seutil==0.4.12
3+
nltk~=3.5
4+
numpy~=1.19.2
5+
pygls~=0.9.1
6+
seutil>=0.5.4
57
six
6-
torchtext==0.4.0
7-
tqdm==4.30.*
8+
torch==1.1.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html
9+
torchtext==0.4.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html
10+
tqdm~=4.30.0

roosterize.spec

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# -*- mode: python ; coding: utf-8 -*-
2+
3+
block_cipher = None
4+
5+
6+
a = Analysis(['roosterize/main.py'],
7+
pathex=['/home/pynie/projects/roosterize-vscode/python'],
8+
binaries=[],
9+
datas=[],
10+
hiddenimports=[],
11+
hookspath=[],
12+
runtime_hooks=[],
13+
excludes=[],
14+
win_no_prefer_redirects=False,
15+
win_private_assemblies=False,
16+
cipher=block_cipher,
17+
noarchive=False)
18+
pyz = PYZ(a.pure, a.zipped_data,
19+
cipher=block_cipher)
20+
exe = EXE(pyz,
21+
a.scripts,
22+
[],
23+
exclude_binaries=True,
24+
name='roosterize',
25+
debug=False,
26+
bootloader_ignore_signals=False,
27+
strip=False,
28+
upx=True,
29+
console=True )
30+
coll = COLLECT(exe,
31+
a.binaries,
32+
a.zipfiles,
33+
a.datas,
34+
strip=False,
35+
upx=True,
36+
upx_exclude=[],
37+
name='roosterize')

roosterize/FilesManager.py

Lines changed: 32 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
from typing import *
2-
31
import math
4-
from pathlib import Path
52
import traceback
6-
from tqdm import tqdm
3+
from pathlib import Path
4+
from typing import Any, Callable, Iterator, List, NoReturn, Optional, Union
75

86
from seutil import IOUtils, LoggingUtils
7+
from tqdm import tqdm
8+
9+
logger = LoggingUtils.get_logger(__name__)
910

1011

1112
class FilesManager:
1213
"""
1314
Handles the loading/dumping of files in a dataset.
1415
"""
15-
logger = LoggingUtils.get_logger(__name__)
1616

1717
ALL_LEMMAS_BACKEND_SEXP_TRANSFORMATIONS = "all-lemmas-bsexp-transformations"
1818
ALL_LEMMAS_FOREEND_SEXP_TRANSFORMATIONS = "all-lemmas-fsexp-transformations"
@@ -34,7 +34,7 @@ def __init__(self, data_dir: Path):
3434
def clean_path(self, rel_path: Union[str, List[str]]):
3535
abs_path = self.data_dir / self.assemble_rel_path(rel_path)
3636
if abs_path.exists():
37-
self.logger.info(f"Removing existing things at {abs_path}")
37+
logger.info(f"Removing existing things at {abs_path}")
3838
IOUtils.rm(abs_path)
3939
# end if
4040
return
@@ -43,7 +43,8 @@ def clean_path(self, rel_path: Union[str, List[str]]):
4343
def is_json_format(cls, fmt: IOUtils.Format) -> bool:
4444
return fmt in [IOUtils.Format.json, IOUtils.Format.jsonPretty, IOUtils.Format.jsonNoSort]
4545

46-
def dump_data(self,
46+
def dump_data(
47+
self,
4748
rel_path: Union[str, List[str]],
4849
data: Any,
4950
fmt: IOUtils.Format,
@@ -53,50 +54,43 @@ def dump_data(self,
5354
):
5455
abs_path = self.data_dir / self.assemble_rel_path(rel_path)
5556
if abs_path.exists() and not exist_ok:
56-
LoggingUtils.log_and_raise(self.logger, f"Cannot rewrite existing data at {abs_path}", IOError)
57-
# end if
57+
raise IOError(f"Cannot rewrite existing data at {abs_path}")
5858

5959
abs_path.parent.mkdir(parents=True, exist_ok=True)
6060
if not is_batched:
6161
if self.is_json_format(fmt):
6262
data = IOUtils.jsonfy(data)
63-
# end if
6463
IOUtils.dump(abs_path, data, fmt)
6564
else:
6665
# In batched mode, the data need to be slice-able and sizable
6766
IOUtils.rm(abs_path)
6867
abs_path.mkdir(parents=True)
6968

70-
for batch_i in tqdm(range(math.ceil(len(data)/per_batch))):
71-
data_batch = data[per_batch*batch_i : per_batch*(batch_i+1)]
69+
for batch_i in tqdm(range(math.ceil(len(data) / per_batch))):
70+
data_batch = data[per_batch * batch_i: per_batch * (batch_i + 1)]
7271
if self.is_json_format(fmt):
7372
data_batch = IOUtils.jsonfy(data_batch)
74-
# end if
75-
IOUtils.dump(abs_path/f"batch-{batch_i}.{fmt.get_extension()}", data_batch, fmt)
76-
# end for
77-
# end if
73+
IOUtils.dump(abs_path / f"batch-{batch_i}.{fmt.get_extension()}", data_batch, fmt)
7874
return
7975

80-
def load_data(self,
76+
def load_data(
77+
self,
8178
rel_path: Union[str, List[str]],
8279
fmt: IOUtils.Format,
8380
is_batched: bool = False,
84-
clz = None,
81+
clz=None,
8582
) -> Any:
8683
if self.is_json_format(fmt) and clz is None:
87-
self.logger.warning(f"Load data from {rel_path} with json format, but did not specify clz (at {traceback.format_stack()})")
88-
# end if
84+
logger.warning(f"Load data from {rel_path} with json format, but did not specify clz (at {traceback.format_stack()})")
8985

9086
abs_path = self.data_dir / self.assemble_rel_path(rel_path)
9187
if not abs_path.exists():
92-
LoggingUtils.log_and_raise(self.logger, f"Cannot find data at {abs_path}", IOError)
93-
# end if
88+
raise IOError(f"Cannot find data at {abs_path}")
9489

9590
if not is_batched:
9691
data = IOUtils.load(abs_path, fmt)
9792
if self.is_json_format(fmt) and clz is not None:
9893
data = IOUtils.dejsonfy(data, clz)
99-
# end if
10094
return data
10195
else:
10296
data = list()
@@ -106,25 +100,21 @@ def load_data(self,
106100
data_batch = IOUtils.load(batch_file, fmt)
107101
if self.is_json_format(fmt) and clz is not None:
108102
data_batch = IOUtils.dejsonfy(data_batch, clz)
109-
# end if
110103
data.extend(data_batch)
111-
# end for
112104
return data
113-
# end if
114105

115-
def iter_batched_data(self,
106+
def iter_batched_data(
107+
self,
116108
rel_path: Union[str, List[str]],
117109
fmt: IOUtils.Format,
118-
clz = None,
110+
clz=None,
119111
) -> Iterator:
120112
if self.is_json_format(fmt) and clz is None:
121-
self.logger.warning(f"Load data from {rel_path} with json format, but did not specify clz")
122-
# end if
113+
logger.warning(f"Load data from {rel_path} with json format, but did not specify clz")
123114

124115
abs_path = self.data_dir / self.assemble_rel_path(rel_path)
125116
if not abs_path.exists():
126-
LoggingUtils.log_and_raise(self.logger, f"Cannot find data at {abs_path}", IOError)
127-
# end if
117+
raise IOError(f"Cannot find data at {abs_path}")
128118

129119
batch_numbers = sorted([int(str(f.stem).split("-")[1]) for f in abs_path.iterdir()])
130120
for batch_number in batch_numbers:
@@ -134,10 +124,12 @@ def iter_batched_data(self,
134124
data_entry = IOUtils.dejsonfy(data_entry, clz)
135125
# end if
136126
yield data_entry
137-
# end for
138-
# end for
139127

140-
def dump_ckpt(self, rel_path: Union[str, List[str]], obj: Any, ckpt_id: int,
128+
def dump_ckpt(
129+
self,
130+
rel_path: Union[str, List[str]],
131+
obj: Any,
132+
ckpt_id: int,
141133
dump_func: Callable[[Any, str], NoReturn],
142134
ckpt_keep_max: int = 5,
143135
) -> NoReturn:
@@ -152,25 +144,23 @@ def dump_ckpt(self, rel_path: Union[str, List[str]], obj: Any, ckpt_id: int,
152144
ckpt_ids = [int(str(f.name)) for f in abs_path.iterdir()]
153145
for ckpt_id in sorted(ckpt_ids)[:-ckpt_keep_max]:
154146
IOUtils.rm(abs_path / str(ckpt_id))
155-
# end for
156-
# end if
157147
return
158148

159-
def load_ckpt(self, rel_path: Union[str, List[str]],
149+
def load_ckpt(
150+
self,
151+
rel_path: Union[str, List[str]],
160152
load_func: Callable[[str], Any],
161153
ckpt_id: Optional[int] = None,
162154
) -> Any:
163155
abs_path = self.data_dir / self.assemble_rel_path(rel_path)
164156
if not abs_path.exists():
165-
LoggingUtils.log_and_raise(self.logger, f"Cannot find data at {abs_path}", IOError)
166-
# end if
157+
raise IOError(f"Cannot find data at {abs_path}")
167158

168159
if ckpt_id is None:
169160
# Find the latest ckpt
170161
ckpt_ids = [int(str(f.name)) for f in abs_path.iterdir()]
171162
ckpt_id = max(ckpt_ids)
172-
self.logger.info(f"Loading the latest checkpoint {ckpt_id} at {abs_path}")
173-
# end if
163+
logger.info(f"Loading the latest checkpoint {ckpt_id} at {abs_path}")
174164

175165
return load_func(str(abs_path / str(ckpt_id)))
176166

@@ -181,5 +171,4 @@ def resolve(self, rel_path: Union[str, List[str]]) -> Path:
181171
def assemble_rel_path(cls, rel_path: Union[str, List[str]]) -> str:
182172
if not isinstance(rel_path, str):
183173
rel_path = "/".join(rel_path)
184-
# end if
185174
return rel_path

roosterize/Macros.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import *
2-
31
import os
42
from pathlib import Path
53

roosterize/Utils.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,3 @@ def counter_most_common_to_pretty_yaml(cls, most_common: List[Tuple[Any, int]])
108108
# end for
109109
s += "]\n"
110110
return s
111-
112-
@classmethod
113-
def modify_and_import(cls, module_name, package, modification_func):
114-
spec = importlib.util.find_spec(module_name, package)
115-
source = spec.loader.get_source(module_name)
116-
new_source = modification_func(source)
117-
module = importlib.util.module_from_spec(spec)
118-
codeobj = compile(new_source, module.__spec__.origin, 'exec')
119-
exec(codeobj, module.__dict__)
120-
sys.modules[module_name] = module
121-
return module

roosterize/data/DataMiner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,8 @@ def collect_lemmas_foreend_sexp_transformations(cls, data_mgr: FilesManager):
616616
VTYPES_DEFINITIONS = [SexpInfo.VernacConsts.type_definition]
617617

618618
@classmethod
619-
def collect_lemmas_doc(cls,
619+
def collect_lemmas_doc(
620+
cls,
620621
doc: CoqDocument,
621622
ast_sexp_list: List[SexpNode],
622623
serapi_options: str,

roosterize/data/ModelSpec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __hash__(self):
2626
def build_from_dict(cls, d: dict) -> "ModelSpec":
2727
model_spec = ModelSpec(
2828
name=d.get("name", ""),
29-
model=d.get("model"),
29+
model=d.get("model", "MultiSourceSeq2Seq"),
3030
config_file=d.get("config-file") if "config-file" in d else None,
3131
config_dict=dict(),
3232
)

0 commit comments

Comments
 (0)