
Commit 21aa867

Change the directories from which lemmas and training data_indexes are read, correspondingly
1 parent: 762fc1b

File tree: 1 file changed, +7 -7 lines

1 file changed

+7
-7
lines changed

roosterize/data/DataMiner.py

Lines changed: 7 additions & 7 deletions
@@ -456,7 +456,7 @@ def collect_lemmas(cls, data_mgr: FilesManager, projects: List[Project], files:
         # Assign uids
         for lemma_i, lemma in enumerate(lemmas): lemma.uid = lemma_i
 
-        data_mgr.dump_data([FilesManager.LEMMAS, "lemmas"], lemmas, IOUtils.Format.json, is_batched=True, per_batch=5000)
+        data_mgr.dump_data([FilesManager.LEMMAS], lemmas, IOUtils.Format.json, is_batched=True, per_batch=5000)
         return
 
     @classmethod
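
The only change in this hunk is the target of the batched dump: the extra "lemmas" path component is dropped, so the batches land directly under the LEMMAS directory instead of a nested subdirectory. A minimal sketch of that difference, assuming (this is an assumption, not the documented roosterize API) that FilesManager joins the relative-key list into a path under the corpus root:

# Sketch only: a stand-in for FilesManager path resolution; the class name,
# constant value, and join behavior are assumptions for illustration.
from pathlib import Path
from typing import List, Union


class FilesManagerSketch:
    LEMMAS = "lemmas"  # assumed value of the FilesManager.LEMMAS constant

    def __init__(self, corpus_path: Union[str, Path]):
        self.corpus_path = Path(corpus_path)

    def resolve(self, rel_path: List[str]) -> Path:
        path = self.corpus_path
        for part in rel_path:
            path = path / part
        return path


fm = FilesManagerSketch("/corpus")
print(fm.resolve([FilesManagerSketch.LEMMAS, "lemmas"]))  # old call target: /corpus/lemmas/lemmas
print(fm.resolve([FilesManagerSketch.LEMMAS]))            # new call target: /corpus/lemmas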
@@ -468,7 +468,7 @@ def filter_lemmas(cls, data_mgr: FilesManager):
         data_mgr.resolve([FilesManager.LEMMAS_FILTERED]).mkdir(parents=True)
 
         # Load lemmas
-        lemmas: List[Lemma] = data_mgr.load_data([FilesManager.LEMMAS, "lemmas"], IOUtils.Format.json, is_batched=True, clz=Lemma)
+        lemmas: List[Lemma] = data_mgr.load_data([FilesManager.LEMMAS], IOUtils.Format.json, is_batched=True, clz=Lemma)
         heights: List[int] = [l.backend_sexp.height() for l in lemmas]
 
         depth_cutoff_point = sorted(heights)[int(np.ceil(Macros.LEMMAS_DEPTH_CUTOFF * len(lemmas)))]
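
The last context line shows how the depth cutoff is derived: rank the lemmas by the height of their backend s-expression and take the height at a fixed quantile of that distribution. A hedged illustration with toy numbers; the real Macros.LEMMAS_DEPTH_CUTOFF value and the exact comparison applied afterwards are not visible in this hunk, so both are assumptions here:

# Toy data; LEMMAS_DEPTH_CUTOFF = 0.8 is an assumed example value, not the real Macros constant.
import numpy as np

LEMMAS_DEPTH_CUTOFF = 0.8

heights = [3, 7, 4, 12, 5, 40, 6, 8, 9, 10]   # one backend-sexp height per lemma
depth_cutoff_point = sorted(heights)[int(np.ceil(LEMMAS_DEPTH_CUTOFF * len(heights)))]
print(depth_cutoff_point)                     # 12

# Assumed filtering step (the operator is a guess; the hunk does not show it):
kept = [h for h in heights if h <= depth_cutoff_point]
print(kept)                                   # drops only the outlier height 40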
@@ -480,7 +480,7 @@ def filter_lemmas(cls, data_mgr: FilesManager):
         # Assign uids
         for lemma_i, lemma in enumerate(lemmas_filtered): lemma.uid = lemma_i
 
-        data_mgr.dump_data([FilesManager.LEMMAS_FILTERED, "lemmas"], lemmas_filtered, IOUtils.Format.json, is_batched=True, per_batch=5000)
+        data_mgr.dump_data([FilesManager.LEMMAS_FILTERED], lemmas_filtered, IOUtils.Format.json, is_batched=True, per_batch=5000)
         return
 
     @classmethod
@@ -529,7 +529,7 @@ def collect_lemmas_backend_sexp_transformations(cls, data_mgr: FilesManager):
         # Increase recursion limit because the backend sexps are CRAZZZZY deep
         sys.setrecursionlimit(10000)
 
-        lemmas_filtered: List[Lemma] = data_mgr.load_data([FilesManager.LEMMAS_FILTERED, "lemmas"], IOUtils.Format.json, is_batched=True, clz=Lemma)
+        lemmas_filtered: List[Lemma] = data_mgr.load_data([FilesManager.LEMMAS_FILTERED], IOUtils.Format.json, is_batched=True, clz=Lemma)
 
         # Main stream transformations, applied one after another
         levels_lemmas_bsexp_transformed: Dict[str, List[SexpNode]] = dict()
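
The context comment explains why the recursion limit is raised before loading: the backend s-expressions are deep enough that a recursive traversal, such as the height() call used in filter_lemmas, would overflow Python's default limit of roughly 1000 frames. A self-contained sketch of that failure mode, using a toy node class rather than the real SexpNode:

# ToySexpNode is a stand-in; it only demonstrates why deep trees need a higher recursion limit.
import sys


class ToySexpNode:
    def __init__(self, children=None):
        self.children = children if children is not None else []

    def height(self) -> int:
        # Recursive height: a couple of stack frames per nesting level.
        if not self.children:
            return 0
        return 1 + max(child.height() for child in self.children)


# A chain 3000 nodes deep: height() raises RecursionError at the default limit,
# but succeeds once the limit is raised as in the diff.
node = ToySexpNode()
for _ in range(3000):
    node = ToySexpNode([node])

sys.setrecursionlimit(10000)
print(node.height())  # 3000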
@@ -573,7 +573,7 @@ def collect_lemmas_foreend_sexp_transformations(cls, data_mgr: FilesManager):
         # Increase recursion limit because the backend sexps are CRAZZZZY deep
         sys.setrecursionlimit(10000)
 
-        lemmas_filtered: List[Lemma] = data_mgr.load_data([FilesManager.LEMMAS_FILTERED, "lemmas"], IOUtils.Format.json, is_batched=True, clz=Lemma)
+        lemmas_filtered: List[Lemma] = data_mgr.load_data([FilesManager.LEMMAS_FILTERED], IOUtils.Format.json, is_batched=True, clz=Lemma)
 
         # Main stream transformations, applied one after another
         levels_lemmas_fsexp_transformed: Dict[str, List[SexpNode]] = dict()
@@ -857,14 +857,14 @@ def extract_data_from_corpus(cls,
         data_mgr = FilesManager(corpus_path)
 
         # 2. Load lemmas and definitions
-        lemmas_filtered: List[Lemma] = data_mgr.load_data([FilesManager.LEMMAS_FILTERED, "lemmas"], IOUtils.Format.json, is_batched=True, clz=Lemma)
+        lemmas_filtered: List[Lemma] = data_mgr.load_data([FilesManager.LEMMAS_FILTERED], IOUtils.Format.json, is_batched=True, clz=Lemma)
         definitions: List[Definition] = data_mgr.load_data([FilesManager.DEFINITIONS, "definitions.json"], IOUtils.Format.json, clz=Definition)
 
         # 3. Output to output_path for each combination of traineval and group
         for traineval in trainevals:
             for group in groups:
                 IOUtils.mk_dir(output_path/f"{group}-{traineval}")
-                data_indexes = data_mgr.load_data([FilesManager.DATA_INDEXES, f"{group}-{traineval}.json"], IOUtils.Format.json, clz=str)
+                data_indexes = IOUtils.load(project_dir/"training"/f"{group}-{traineval}.json", IOUtils.Format.json, clz=str)
                 IOUtils.dump(output_path/f"{group}-{traineval}/lemmas.json", IOUtils.jsonfy([l for l in lemmas_filtered if l.data_index in data_indexes]), IOUtils.Format.json)
                 IOUtils.dump(output_path/f"{group}-{traineval}/definitions.json", IOUtils.jsonfy([d for d in definitions if d.data_index in data_indexes]), IOUtils.Format.json)
             # end for
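
After this change the per-split index files are read from a "training" directory rather than through the data manager's DATA_INDEXES key; each split's indexes then select the matching lemmas and definitions in the two IOUtils.dump lines at the end of the hunk. A hedged sketch of that selection step, with stand-in classes and made-up data_index values:

# Sketch of the per-split selection performed by the two IOUtils.dump lines:
# keep only records whose data_index appears in the split's index list.
from dataclasses import dataclass
from typing import List


@dataclass
class LemmaStub:
    data_index: str   # identifies the source file/project the lemma came from
    name: str = ""


def select_split(records: List[LemmaStub], data_indexes: List[str]) -> List[LemmaStub]:
    index_set = set(data_indexes)   # set membership instead of repeated list scans
    return [r for r in records if r.data_index in index_set]


lemmas = [LemmaStub("proj-a/A.v", "lem1"), LemmaStub("proj-b/B.v", "lem2")]
train_indexes = ["proj-a/A.v"]      # e.g. loaded from training/<group>-<traineval>.json
print(select_split(lemmas, train_indexes))  # keeps only lem1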
