forked from apeterswu/RL4NMT
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path train_eten.sh
More file actions
executable file
36 lines (28 loc) · 917 Bytes
/
train_eten.sh
File metadata and controls
executable file
36 lines (28 loc) · 917 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Environment and path setup for training a Transformer on the translate_eten
# problem with the bundled tensor2tensor trainer (invoked right after this
# setup). Every path below is relative to the current directory, so the script
# presumably has to be launched from the repository root -- TODO confirm.

# Put the current directory first on the Python module path. The
# ${PYTHONPATH:+:...} expansion appends ":<old value>" only when PYTHONPATH is
# already non-empty, so an unset PYTHONPATH no longer leaves a trailing empty
# entry ("./:") on the path.
export PYTHONPATH="./${PYTHONPATH:+:${PYTHONPATH}}"
# Expose only GPU index 3 to CUDA for this run.
export CUDA_VISIBLE_DEVICES=3
# Directory containing the tensor2tensor entry-point scripts (t2t-trainer).
binFile=./tensor2tensor/bin
PROBLEM=translate_eten
MODEL=transformer
# Hyper-parameter set to train with; the commented-out lines are alternative
# settings that can be swapped in.
HPARAMS=eten_transformer_rl_delta_setting
# HPARAMS=zhen_wmt17_transformer_rl_delta_setting_random
# HPARAMS=zhen_wmt17_transformer_rl_total_setting
# HPARAMS=zhen_wmt17_transformer_rl_total_setting_random
# HPARAMS=zhen_wmt17_transformer_rl_delta_setting_random_baseline
# HPARAMS=zhen_wmt17_transformer_rl_delta_setting_random_mle
DATA_DIR=../transformer_data/eten
# Output directory is named after the hparams set, keeping runs with
# different settings in separate directories.
TRAIN_DIR=./model/${HPARAMS}
# Create the output directory up front. Stop here instead of launching a long
# training run if it cannot be created.
mkdir -p -- "$TRAIN_DIR" || exit 1
${binFile}/t2t-trainer \
--t2t_usr_dir=./eten \
--data_dir=$DATA_DIR \
--problems=$PROBLEM \
--model=$MODEL \
--hparams_set=$HPARAMS \
--output_dir=$TRAIN_DIR \
--train_steps=300000 \
--save_checkpoints_steps=500 \
--keep_checkpoint_max=50 \
--local_eval_frequency=1000000 \
--hparams='batch_size=64,learning_rate=0.0001' \
--eval_steps=3 \
--worker_gpu=1 \