Skip to content

Commit a821e6b

Browse files
authored
Merge pull request #1 from EngineeringSoftware/add-training-data
add training directory defining training sets, update README.md
2 parents bf46fce + 21aa867 commit a821e6b

27 files changed

+2710
-12
lines changed

README.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
Roosterize is a tool for suggesting lemma names in verification
44
projects that use the [Coq proof assistant](https://coq.inria.fr).
5+
The tool is based on leveraging neural networks that take serialized Coq
6+
lemma statements and elaborated terms as input; see the [Technique](#Technique)
7+
section below.
58

69
## Requirements
710

@@ -78,7 +81,7 @@ project, and `$SERAPI_OPTIONS` should be replaced with the SerAPI
7881
command line options for mapping logical paths to directories (see [SerAPI's
7982
documentation](https://github.com/ejgallego/coq-serapi/blob/v8.11/FAQ.md#does-serapi-support-coqs-command-line-flags)).
8083
For example, if the logical path (inside Coq) for the project is `Verified`,
81-
you should set `SERAPI_OPTIONS="-Q $PATH_TO_PROJECT,Verified"`.
84+
you should set `SERAPI_OPTIONS="-R $PATH_TO_PROJECT,Verified"`.
8285

8386
The command extracts all lemmas from the project, uses Roosterize's
8487
pre-trained model (at `./models/roosterize-ta`) to predict a lemma name
@@ -90,8 +93,6 @@ Below is an example of printed suggestions:
9093
infotheo/ecc_classic/bch.v: infotheo.ecc_classic.bch.BCH.BCH_PCM_altP1 -> inde_F2
9194
infotheo/ecc_classic/bch.v: infotheo.ecc_classic.bch.BCH.BCH_PCM_altP2 -> inde_mul
9295
infotheo/ecc_classic/bch.v: infotheo.ecc_classic.bch.BCH.PCM_altP -> F2_eq0
93-
infotheo/ecc_classic/bch.v: infotheo.ecc_classic.bch.BCH.PCM_alt_GRS -> P
94-
infotheo/ecc_classic/bch.v: infotheo.ecc_classic.bch.BCH_codebook -> map_P
9596
...
9697
```
9798

@@ -109,7 +110,7 @@ For example, the Coq lemma sentence
109110
```coq
110111
Lemma mg_eq_proof L1 L2 (N1 : mgClassifier L1) : L1 =i L2 -> nerode L2 N1.
111112
```
112-
is serialized into the following tokens:
113+
is serialized into the following tokens (simplified):
113114
```lisp
114115
(Sentence((IDENT Lemma)(IDENT mg_eq_proof)(IDENT L1)(IDENT L2)
115116
(KEYWORD"(")(IDENT N1)(KEYWORD :)(IDENT mgClassifier)
@@ -134,7 +135,8 @@ architecture, as applied to this example:
134135
Our [research paper][arxiv-paper] outlines the design of Roosterize,
135136
and describes an evaluation on a [corpus][math-comp-corpus]
136137
of serialized Coq code derived from the [Mathematical Components][math-comp-website]
137-
family of projects.
138+
family of projects. The training, validation, and testing sets of Coq files from the corpus
139+
used in the evaluation are defined in the `training` directory.
138140

139141
If you have used Roosterize in a research project, please cite
140142
the research paper in any related publication:

roosterize/data/DataMiner.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ def collect_lemmas(cls, data_mgr: FilesManager, projects: List[Project], files:
456456
# Assign uids
457457
for lemma_i, lemma in enumerate(lemmas): lemma.uid = lemma_i
458458

459-
data_mgr.dump_data([FilesManager.LEMMAS, "lemmas"], lemmas, IOUtils.Format.json, is_batched=True, per_batch=5000)
459+
data_mgr.dump_data([FilesManager.LEMMAS], lemmas, IOUtils.Format.json, is_batched=True, per_batch=5000)
460460
return
461461

462462
@classmethod
@@ -468,7 +468,7 @@ def filter_lemmas(cls, data_mgr: FilesManager):
468468
data_mgr.resolve([FilesManager.LEMMAS_FILTERED]).mkdir(parents=True)
469469

470470
# Load lemmas
471-
lemmas: List[Lemma] = data_mgr.load_data([FilesManager.LEMMAS, "lemmas"], IOUtils.Format.json, is_batched=True, clz=Lemma)
471+
lemmas: List[Lemma] = data_mgr.load_data([FilesManager.LEMMAS], IOUtils.Format.json, is_batched=True, clz=Lemma)
472472
heights: List[int] = [l.backend_sexp.height() for l in lemmas]
473473

474474
depth_cutoff_point = sorted(heights)[int(np.ceil(Macros.LEMMAS_DEPTH_CUTOFF * len(lemmas)))]
@@ -480,7 +480,7 @@ def filter_lemmas(cls, data_mgr: FilesManager):
480480
# Assign uids
481481
for lemma_i, lemma in enumerate(lemmas_filtered): lemma.uid = lemma_i
482482

483-
data_mgr.dump_data([FilesManager.LEMMAS_FILTERED, "lemmas"], lemmas_filtered, IOUtils.Format.json, is_batched=True, per_batch=5000)
483+
data_mgr.dump_data([FilesManager.LEMMAS_FILTERED], lemmas_filtered, IOUtils.Format.json, is_batched=True, per_batch=5000)
484484
return
485485

486486
@classmethod
@@ -529,7 +529,7 @@ def collect_lemmas_backend_sexp_transformations(cls, data_mgr: FilesManager):
529529
# Increase recursion limit because the backend sexps are CRAZZZZY deep
530530
sys.setrecursionlimit(10000)
531531

532-
lemmas_filtered: List[Lemma] = data_mgr.load_data([FilesManager.LEMMAS_FILTERED, "lemmas"], IOUtils.Format.json, is_batched=True, clz=Lemma)
532+
lemmas_filtered: List[Lemma] = data_mgr.load_data([FilesManager.LEMMAS_FILTERED], IOUtils.Format.json, is_batched=True, clz=Lemma)
533533

534534
# Main stream transformations, applied one after another
535535
levels_lemmas_bsexp_transformed: Dict[str, List[SexpNode]] = dict()
@@ -573,7 +573,7 @@ def collect_lemmas_foreend_sexp_transformations(cls, data_mgr: FilesManager):
573573
# Increase recursion limit because the backend sexps are CRAZZZZY deep
574574
sys.setrecursionlimit(10000)
575575

576-
lemmas_filtered: List[Lemma] = data_mgr.load_data([FilesManager.LEMMAS_FILTERED, "lemmas"], IOUtils.Format.json, is_batched=True, clz=Lemma)
576+
lemmas_filtered: List[Lemma] = data_mgr.load_data([FilesManager.LEMMAS_FILTERED], IOUtils.Format.json, is_batched=True, clz=Lemma)
577577

578578
# Main stream transformations, applied one after another
579579
levels_lemmas_fsexp_transformed: Dict[str, List[SexpNode]] = dict()
@@ -857,14 +857,14 @@ def extract_data_from_corpus(cls,
857857
data_mgr = FilesManager(corpus_path)
858858

859859
# 2. Load lemmas and definitions
860-
lemmas_filtered: List[Lemma] = data_mgr.load_data([FilesManager.LEMMAS_FILTERED, "lemmas"], IOUtils.Format.json, is_batched=True, clz=Lemma)
860+
lemmas_filtered: List[Lemma] = data_mgr.load_data([FilesManager.LEMMAS_FILTERED], IOUtils.Format.json, is_batched=True, clz=Lemma)
861861
definitions: List[Definition] = data_mgr.load_data([FilesManager.DEFINITIONS, "definitions.json"], IOUtils.Format.json, clz=Definition)
862862

863863
# 3. Output to output_path for each combination of traineval and group
864864
for traineval in trainevals:
865865
for group in groups:
866866
IOUtils.mk_dir(output_path/f"{group}-{traineval}")
867-
data_indexes = data_mgr.load_data([FilesManager.DATA_INDEXES, f"{group}-{traineval}.json"], IOUtils.Format.json, clz=str)
867+
data_indexes = IOUtils.load(project_dir/"training"/f"{group}-{traineval}.json", IOUtils.Format.json, clz=str)
868868
IOUtils.dump(output_path/f"{group}-{traineval}/lemmas.json", IOUtils.jsonfy([l for l in lemmas_filtered if l.data_index in data_indexes]), IOUtils.Format.json)
869869
IOUtils.dump(output_path/f"{group}-{traineval}/definitions.json", IOUtils.jsonfy([d for d in definitions if d.data_index in data_indexes]), IOUtils.Format.json)
870870
# end for

0 commit comments

Comments
 (0)