
Commit 527155d

Baselines for the Halite Competition (#706)
* add baselines for the halite competition
* modify the code format
* delete useless files
* delete __init__.py files
* modify the code style and delete useless files
* modify code style
* modify readme.md
* reduce the file size of animation.gif
* recover the size of animation.gif in Halite-Competition/torch
* modify the config.py and encode_model.py
* modify the backend of parl to be torch and the annotation
* modify the annotation and delete useless codes
* modify the code style by yapf
* modify the kaggle_environments to zerosum_env
* reduce the size of test ipython notebook
* fix some typos
* fix some typos
* delete assets and change the address of gifs and imgs to that of parl-experiments
* update requirement.txt and use parl.utils.logger for logging
* better code style in requirement.txt
1 parent 56cf3f5 commit 527155d

34 files changed: +6069 −0 lines
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
# Reinforcement Learning Baseline for the Halite Competition

Based on the PARL framework, we provide a baseline solution using the PPO algorithm.


## Contents

* config.py : hyperparameter configuration
* train.py : training script
* test.py : evaluation script
* submission.py : submission example
* rl_trainer
    * model.py : defines the actor and critic network architectures
    * agent.py : handles the interaction between the algorithm and the environment, including feeding data to the algorithm for training
    * algorithm.py : PPO algorithm implementation
    * controller.py : tracks the state of each ship, shapes the reward, and collects training data
    * policy.py : defines the rule-based policy used to control the shipyard
    * obs_parser.py : builds the observation of each ship


## Baseline Design

We use the PPO algorithm to control each ship, with all ships sharing the same model parameters.

Each ship aims to collect K units of halite as quickly as possible (K is a hyperparameter). Once a ship has finished collecting, it returns to the shipyard, deposits the halite, and then starts a new round of collection. In addition, when the episode is about to end (i.e., the maximum number of interaction steps is reached), the ship is also forced to return. In other words, the collection phase of a ship is controlled by the model, while the remaining phases are controlled by rules. The shipyard is controlled entirely by rules, and its goal is to spawn M ships as quickly as possible (M is a hyperparameter).
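Below is a minimal, self-contained sketch of this per-ship control flow. The function and constant names (`ship_phase`, `NUM_HALITE`, `MAX_EPISODE_STEPS`) and the returned labels are illustrative assumptions; the actual logic lives in rl_trainer/controller.py.

```python
# Illustrative sketch only -- names and the episode length are assumptions;
# see rl_trainer/controller.py for the real implementation.

NUM_HALITE = 100          # K: halite a ship should collect before returning
MAX_EPISODE_STEPS = 400   # assumed maximum number of interaction steps

def ship_phase(ship_halite, step, steps_needed_to_return):
    """Decide whether a ship is currently rule-controlled or model-controlled."""
    # Forced return: the episode is about to end.
    if MAX_EPISODE_STEPS - step <= steps_needed_to_return:
        return "return_by_rule"
    # The ship has collected K units of halite: return and deposit it.
    if ship_halite >= NUM_HALITE:
        return "return_by_rule"
    # Otherwise the PPO model picks one of the five ship actions.
    return "collect_by_model"

if __name__ == "__main__":
    print(ship_phase(ship_halite=20, step=10, steps_needed_to_return=8))   # collect_by_model
    print(ship_phase(ship_halite=120, step=10, steps_needed_to_return=8))  # return_by_rule
```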
## Quick Start

Create and activate a virtual Python environment

```shell
conda create -n halite python==3.6

source activate halite
```

Install the dependencies

```shell
pip install -r requirements.txt
```

## Training

After modifying the hyperparameters in config.py, run the following command:

```shell
python train.py
```

## Testing

After training finishes, set the path of your saved model in test.py and run the script to evaluate your model.

```shell
python test.py
```

Note that this test script uses a built-in random agent as the opponent. If you want to compare against other agents, replace "random" with the corresponding agent function. You can also use this script to check that your code runs without errors before submitting your model and solution to the platform.


## Results

The figure below shows the learning curve of the PPO baseline. Currently, we only train the model under a single fixed seed, with a random agent as the opponent. To achieve a good ranking in the competition, participants should train a more robust model (e.g., one that can handle environments with different halite distributions as well as both 1 vs 1 and 1 vs 3 scenarios).

![learning curve](https://github.com/benchmarking-rl/PARL-experiments/blob/master/Baselines/Halite_Competition/paddle/learning_curve.jpg?raw=true)

## Visualization

To watch a rendered match, start a Jupyter Notebook environment, open test.ipynb, and run the cells to see the animation.

![animation](https://github.com/benchmarking-rl/PARL-experiments/blob/master/Baselines/Halite_Competition/paddle/animation.gif?raw=true)

## Submission

Currently, participants can only submit a single file to the platform, so all required functions and the model must be placed in that one file. To load a model inside the file, first encode the model into a byte string and embed it in the file, then decode the byte string wherever the model needs to be loaded. See encode_model.py for how to encode a model, and submission.py for a submission example that loads the model.

Note that the scoring system only calls the last function defined in the submission file. Therefore, the function that produces your agent's actions must be placed at the end of the file; it takes observation and configuration as inputs and returns the actions of every ship and shipyard. See submission.py for details.
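As a rough, hypothetical skeleton of such a file (submission.py is the authoritative example; `MODEL_BYTES`, `decode_actor`, and the body of `agent` below are placeholders):

```python
import base64
import pickle

# Paste the base64 string produced by encode_model.py here (placeholder).
MODEL_BYTES = b"..."

def decode_actor(model_bytes):
    """Decode the base64/pickle byte string back into a dict of numpy weights."""
    return pickle.loads(base64.b64decode(model_bytes))

# ... network definition, observation parsing, rule-based shipyard policy ...

def agent(observation, configuration):
    """Must be the last function in the file: this is what the scoring system calls.

    It maps the current observation to an action for every ship and shipyard,
    e.g. {"ship_id": "NORTH", "shipyard_id": "SPAWN"}.
    """
    actions = {}
    # 1. parse the observation into per-ship features (cf. obs_parser.py)
    # 2. run the decoded actor network to choose each ship's move
    # 3. apply the rule-based policy for the shipyard (cf. policy.py)
    return actions
```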
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

config = {

    # configuration for the environment
    "board_size": 21,

    # configuration for training
    "episodes": 100000,
    "batch_size": 128,
    "train_times": 2,
    "gamma": 0.997,
    "lr": 0.0001,
    "test_every_episode": 100,

    # configuration for the PPO algorithm
    "vf_loss_coef": 1,
    "ent_coef": 0.01,

    # configuration for the observation of ships
    "world_dim": 5 * 21 * 21,
    "ship_obs_dim": 6,
    "ship_act_dim": 5,
    "ship_max_step": 10000,

    # the amount of halite we want each ship to collect (i.e. K)
    "num_halite": 100,

    # the maximum number of ships (i.e. M)
    "num_ships": 10,

    # seed for training
    "seed": 5609,

    # configuration for logging
    "log_path": './train_log/',
    "save_path": './save_model/',
}
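For reference, a minimal sketch of how this dict might be consumed; it assumes the training and test scripts simply import the module-level dict (check train.py for the actual usage):

```python
# Minimal usage sketch; assumes the scripts import the dict directly.
from config import config

print(config["lr"])           # 0.0001, learning rate for the PPO optimizer
print(config["num_halite"])   # K: target halite per collection trip
print(config["num_ships"])    # M: maximum number of ships to spawn
```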
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import pickle
import paddle

if __name__ == '__main__':

    # load the trained model and keep only the actor's state dict
    model = paddle.load('./model/latest_ship_model.pth')
    actor = model['actor']

    # convert the paddle tensors to numpy arrays so they can be pickled
    for name, param in actor.items():
        actor[name] = param.numpy()

    # pickle the weights, base64-encode them, and write the byte string to a file
    model_byte = base64.b64encode(pickle.dumps(actor))
    with open('./model/actor.txt', 'wb') as f:
        f.write(model_byte)
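The decoding counterpart, used inside the single-file submission, would reverse these steps. A minimal sketch, assuming the weights are converted back to paddle tensors before being loaded into the actor network (the actor instance itself is hypothetical here):

```python
# Minimal decoding sketch: the reverse of encode_model.py.
import base64
import pickle

import paddle

with open('./model/actor.txt', 'rb') as f:
    model_byte = f.read()

# base64 -> pickle -> dict of numpy arrays
actor_numpy = pickle.loads(base64.b64decode(model_byte))

# convert the arrays back to paddle tensors so they can form a state dict
actor_state_dict = {name: paddle.to_tensor(param) for name, param in actor_numpy.items()}

# actor_network.set_state_dict(actor_state_dict)  # hypothetical actor instance (cf. model.py)
```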
Binary file not shown.
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
parl>=2.0.0
paddlepaddle>=2.0.0
Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import parl
import paddle


class Agent(parl.Agent):
    """Agent.

    Args:
        algorithm (`parl.Algorithm`): algorithm to be used in this agent.
    """

    def __init__(self, algorithm):
        self.alg = algorithm

    def learn(self, obs, act, value, returns, log_prob, adv):
        """Update the actor and critic networks.

        Args:
            obs (np.array): representation of current observations
            act (np.array): current actions
            value (np.array): state values
            returns (np.array): discounted returns
            log_prob (np.array): log probabilities of the actions
            adv (np.array): advantage values
        """
        obs = paddle.to_tensor(obs, dtype=paddle.float32)
        act = paddle.to_tensor(act, dtype=paddle.int32)
        value = paddle.to_tensor(value, dtype=paddle.float32)
        returns = paddle.to_tensor(returns, dtype=paddle.float32)
        log_prob = paddle.to_tensor(log_prob, dtype=paddle.float32)
        adv = paddle.to_tensor(adv, dtype=paddle.float32)

        value_loss, action_loss, entropy = self.alg.learn(
            obs, act, value, returns, log_prob, adv)

        return value_loss, action_loss, entropy

    def predict(self, state):
        """Predict the action.

        Args:
            state (np.array): representation of current state

        Return:
            action (np.array): action to be executed
        """
        state_tensor = paddle.to_tensor(state, dtype=paddle.float32)

        with paddle.no_grad():
            action = self.alg.predict(state_tensor).cpu().numpy()

        return action

    def sample(self, state):
        """Sample an action from the current policy.

        Args:
            state (np.array): representation of current state

        Return:
            value (np.array): state value
            action (np.array): action to be executed
            action_log_prob (np.array): log probability of the sampled action
        """
        state_tensor = paddle.to_tensor(state, dtype=paddle.float32)

        with paddle.no_grad():
            value, action, action_log_prob = self.alg.sample(state_tensor)

        value = value.detach().cpu().numpy().flatten()
        action = action.detach().cpu().numpy()
        action_log_prob = action_log_prob.cpu().numpy()

        return value, action, action_log_prob

    def value(self, state):
        """Predict the critic value.

        Args:
            state (np.array): representation of current state

        Return:
            value (np.array): state value
        """
        state_tensor = paddle.to_tensor(state, dtype=paddle.float32)

        with paddle.no_grad():
            value = self.alg.value(state_tensor).cpu().numpy()

        return value

    def save(self, model_path):
        """Save the model.

        Args:
            model_path (str): the path to save the model
        """
        sep = os.sep
        dirname = sep.join(model_path.split(sep)[:-1])
        if dirname != '' and not os.path.exists(dirname):
            os.makedirs(dirname)
        model_dict = {}
        model_dict["critic"] = self.alg.critic.state_dict()
        model_dict["actor"] = self.alg.actor.state_dict()
        model_dict["optim"] = self.alg.optim.state_dict()
        paddle.save(model_dict, model_path)

    def restore(self, model_path):
        """Restore the model.

        Args:
            model_path (str): the path to restore the model from
        """
        model_dict = paddle.load(model_path)
        self.alg.critic.set_state_dict(model_dict["critic"])
        self.alg.actor.set_state_dict(model_dict["actor"])
        self.alg.optim.set_state_dict(model_dict["optim"])
0 commit comments

Comments
 (0)