Skip to content

Commit ae98067

Browse files
committed
Including a script for preparing the pre-trained models.
1 parent 9a65511 commit ae98067

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

pretrain.sh

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#!/bin/bash
2+
3+
_DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )
4+
5+
6+
# This script is for pre-training the models
7+
8+
9+
function process_data() {
10+
local ds=$1; shift
11+
12+
( cd $_DIR/python/
13+
rm -rf $_DIR/output/$ds/data
14+
python -m roosterize.main extract_data_from_corpus\
15+
--corpus=$_DIR/../math-comp-corpus\
16+
--output=$_DIR/output/$ds/data\
17+
--groups=$ds
18+
)
19+
}
20+
21+
function train_model() {
22+
local ds=$1; shift
23+
24+
( cd $_DIR/python/
25+
rm -rf $_DIR/output/$ds/model
26+
python -m roosterize.main train_model\
27+
--train=$_DIR/output/$ds/data/$ds-train\
28+
--val=$_DIR/output/$ds/data/$ds-val\
29+
--model-dir=$_DIR/output/$ds/model\
30+
--output=$_DIR/output/$ds/data\
31+
--config-file=$_DIR/configs/Stmt+ChopKnlTree+attn+copy.json
32+
)
33+
}
34+
35+
function package_model() {
36+
local ds=$1; shift
37+
38+
( cd $_DIR/output/$ds/
39+
tar czf roosterize-model-$ds.tgz model/
40+
)
41+
}
42+
43+
function eval_model() {
44+
local ds=$1; shift
45+
46+
( cd $_DIR/python/
47+
rm -rf $_DIR/output/$ds/results
48+
python -m roosterize.main eval_model\
49+
--data=$_DIR/output/$ds/data/$ds-test\
50+
--model-dir=$_DIR/output/$ds/model\
51+
--output=$_DIR/output/$ds/results
52+
)
53+
}
54+
55+
function retrain_all_models() {
56+
for ds in t1 ta; do
57+
process_data $ds
58+
train_model $ds
59+
package_model $ds
60+
done
61+
}
62+
63+
64+
# ==========
65+
# Main function -- program entry point
66+
# This script can be executed as ./run.sh the_function_to_run
67+
68+
function main() {
69+
local action=${1:?Need Argument}; shift
70+
71+
( cd ${_DIR}
72+
$action "$@"
73+
)
74+
}
75+
76+
main "$@"
77+

0 commit comments

Comments
 (0)