Skip to content

Commit 8942ca9

Browse files
committed
fix bug, add relative path
1 parent c55bb4d commit 8942ca9

File tree

10 files changed

+21
-13
lines changed

10 files changed

+21
-13
lines changed

README.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ pip3 install vllm==0.8.3
1616
1717
# Install flash-attn
1818
pip3 install flash-attn --no-build-isolation
19+
pip3 install tensorboard
1920
```
2021

2122
2. Prepare environment for ALFWorld
@@ -25,7 +26,9 @@ conda activate alfworld
2526
2627
# download task for training
2728
pip install alfworld
28-
alfworld-download
29+
pip install fastapi
30+
pip install uvicorn
31+
alfworld-download --data-dir ./get_data/alfworld
2932
```
3033

3134
3. Prepare environment for ScienceWorld
@@ -34,12 +37,15 @@ conda create --name scienceworld python=3.8
3437
conda activate scienceworld
3538
3639
pip install scienceworld
40+
pip install fastapi
41+
pip install uvicorn
3742
```
3843

3944
## 2. Prepare for data
4045
```
4146
# get task data for rl training
42-
bash get_data/get_data_for_training.sh
47+
cd get_data
48+
bash get_data_for_training.sh
4349
```
4450

4551
## 3. Start training

cmd/alf.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ if ss -tuln | grep -q ":$PORT "; then
3232
echo "端口 $PORT 已被占用"
3333
else
3434
echo "$PORT 未被占用"
35-
conda activate /path/to/alfworld-env
35+
conda activate alfworld
3636
cd $REPO_HOME/verl/alfworld_server/server
3737
server_cmd="python start_server.py --num_servers 8"
3838

@@ -43,7 +43,7 @@ else
4343
fi
4444

4545
cd $REPO_HOME
46-
conda activate /path/to/embodied-r1-env
46+
conda activate embodied-r1
4747
cmd="bash ${bash_path}"
4848
echo "Running $cmd"
4949

cmd/sci_easy.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ if ss -tuln | grep -q ":$PORT "; then
3232
echo "端口 $PORT 已被占用"
3333
else
3434
echo "$PORT 未被占用"
35-
conda activate /path/to/sciworld-env
35+
conda activate scienceworld
3636
cd $REPO_HOME/verl/scienceworld_server
3737
server_cmd="python start_server.py --num_servers 8"
3838

@@ -43,7 +43,7 @@ else
4343
fi
4444

4545
cd $REPO_HOME
46-
conda activate /path/to/embodied-r1-env
46+
conda activate embodied-r1
4747
cmd="bash ${bash_path}"
4848
echo "Running $cmd"
4949

cmd/sci_nornal.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ if ss -tuln | grep -q ":$PORT "; then
3232
echo "端口 $PORT 已被占用"
3333
else
3434
echo "$PORT 未被占用"
35-
conda activate /path/to/sciworld-env
35+
conda activate scienceworld
3636
cd $REPO_HOME/verl/scienceworld_server
3737
server_cmd="python start_server.py --num_servers 8"
3838

@@ -43,7 +43,7 @@ else
4343
fi
4444

4545
cd $REPO_HOME
46-
conda activate /path/to/embodied-r1-env
46+
conda activate embodied-r1
4747
cmd="bash ${bash_path}"
4848
echo "Running $cmd"
4949

examples/grpo_trainer/alf.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ ray start --head
99
python -m verl.trainer.main_ppo_alf \
1010
algorithm.adv_estimator=grpo \
1111
data.train_files=get_data/rl/alf_train.json \
12-
data.val_files=get_data/rl/alf_seen.json \
12+
data.val_files=get_data/rl/alf_valid_seen.json \
1313
data.train_batch_size=128 \
1414
+data.max_length=4096 \
1515
+data.max_steps=30 \

examples/grpo_trainer/sci_easy.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ ray start --head
99
python -m verl.trainer.main_ppo_sci \
1010
algorithm.adv_estimator=grpo \
1111
data.train_files=get_data/rl/sci_train.json \
12-
data.val_files=get_data/rl/sci_seen.json \
12+
data.val_files=get_data/rl/sci_dev.json \
1313
data.train_batch_size=64 \
1414
+data.max_length=4096 \
1515
+data.max_steps=30 \

examples/grpo_trainer/sci_normal.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ set -x
33

44
export system_prompt='You are a helpful assistant to do some scientific experiment in an environment.\nYou should explore the environment and find the items you need to complete the experiment.\n\nIn the environment, there are several rooms: kitchen, foundry, workshop, bathroom, outside, living room, bedroom, greenhouse, art studio, hallway.\nThe available actions are:\nactivate OBJ\nclose OBJ\nconnect OBJ to OBJ\ndeactivate OBJ\ndisconnect OBJ\ndunk OBJ in OBJ\neat OBJ\nflush OBJ\nfocus on OBJ\ngo LOC\ninventory\nlook around\nlook at OBJ\nlook in OBJ\nmix OBJ\nmove OBJ to OBJ\nopen OBJ\npick up OBJ\npour OBJ in OBJ\nput down OBJ\nread OBJ\nuse OBKJ on OBJ\nwait: wait 10 steps\nwait1: wait 1 step\ntask: check your task\ndone: indicate that you believe the task is complete\nWhen arrive a new location, you should use look around to check the OBj you can interact with.\nUse focus on OBJ only neccessary as incorrect use will cause environment ends.\nDo not proceed with any further exploration or actions until you receive the feedback from the environment after your action.\nYour response should use the following format:\n\nThought: <your thoughts>\nAction: <your next action>'
55
start_port=8000
6-
model_path='/path/to/Qwen/Qwen2.5-7B-Instruct'
6+
model_path='/path/to/Qwen2.5-7B-Instruct'
77

88
ray start --head
99
python -m verl.trainer.main_ppo_sci \
1010
algorithm.adv_estimator=grpo \
1111
data.train_files=get_data/rl/sci_train.json \
12-
data.val_files=get_data/rl/sci_seen.json \
12+
data.val_files=get_data/rl/sci_dev.json \
1313
data.train_batch_size=64 \
1414
+data.max_length=4096 \
1515
+data.max_steps=30 \

get_data/get_data_for_training.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ python -m utils.modify_alf_sft --input ${ORIGIN_SFT}/data/sciworld_sft.json --ou
1212

1313

1414
# 2. task for rl
15-
ALF_GAMEFILE_PATH='~/.cache/alfworld'
15+
ALF_GAMEFILE_PATH='./alfworld'
1616
python -m utils.generate_alf_indice --input "${ALF_GAMEFILE_PATH}/json_2.1.1/train" --output ./rl/alf_train.json
1717
python -m utils.generate_alf_indice --input "${ALF_GAMEFILE_PATH}/json_2.1.1/valid_seen" --output ./rl/alf_valid_seen.json
1818
python -m utils.generate_alf_indice --input "${ALF_GAMEFILE_PATH}/json_2.1.1/valid_unseen" --output ./rl/alf_valid_unseen.json

verl/workers/rollout/vllm_rollout/alf_rollout.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf
129129
max_num_batched_tokens=max_num_batched_tokens,
130130
enable_chunked_prefill=config.enable_chunked_prefill,
131131
enable_prefix_caching=True,
132+
seed=42,
132133
)
133134
else:
134135
raise NotImplementedError

verl/workers/rollout/vllm_rollout/sci_rollout.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf
129129
max_num_batched_tokens=max_num_batched_tokens,
130130
enable_chunked_prefill=config.enable_chunked_prefill,
131131
enable_prefix_caching=True,
132+
seed=42,
132133
)
133134
else:
134135
raise NotImplementedError

0 commit comments

Comments
 (0)