Skip to content

Commit 7e104a5

Browse files
author
Georg Martius
committed
The default env is continuous now. There is a function to convert discrete actions into continuous ones.
1 parent 48372e0 commit 7e104a5

File tree

2 files changed

+145
-36
lines changed

2 files changed

+145
-36
lines changed

Laser-Hockey-Env.ipynb

Lines changed: 121 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 1,
5+
"execution_count": 22,
66
"metadata": {
77
"ExecuteTime": {
8-
"end_time": "2018-12-21T11:13:01.818505Z",
9-
"start_time": "2018-12-21T11:13:01.670659Z"
8+
"end_time": "2018-12-21T12:23:36.441678Z",
9+
"start_time": "2018-12-21T12:23:36.437323Z"
1010
}
1111
},
1212
"outputs": [],
@@ -19,11 +19,11 @@
1919
},
2020
{
2121
"cell_type": "code",
22-
"execution_count": 8,
22+
"execution_count": 23,
2323
"metadata": {
2424
"ExecuteTime": {
25-
"end_time": "2018-12-21T11:13:16.094022Z",
26-
"start_time": "2018-12-21T11:13:16.090695Z"
25+
"end_time": "2018-12-21T12:23:36.859856Z",
26+
"start_time": "2018-12-21T12:23:36.854044Z"
2727
}
2828
},
2929
"outputs": [],
@@ -45,11 +45,11 @@
4545
},
4646
{
4747
"cell_type": "code",
48-
"execution_count": 13,
48+
"execution_count": 25,
4949
"metadata": {
5050
"ExecuteTime": {
51-
"end_time": "2018-12-21T11:20:15.026155Z",
52-
"start_time": "2018-12-21T11:20:15.019530Z"
51+
"end_time": "2018-12-21T12:23:58.180935Z",
52+
"start_time": "2018-12-21T12:23:58.169092Z"
5353
}
5454
},
5555
"outputs": [
@@ -59,7 +59,7 @@
5959
"<module 'laser_hockey_env' from '/home/georg/src/python/laser-hockey-env/laser_hockey_env.py'>"
6060
]
6161
},
62-
"execution_count": 13,
62+
"execution_count": 25,
6363
"metadata": {},
6464
"output_type": "execute_result"
6565
}
@@ -70,16 +70,16 @@
7070
},
7171
{
7272
"cell_type": "code",
73-
"execution_count": 16,
73+
"execution_count": 26,
7474
"metadata": {
7575
"ExecuteTime": {
76-
"end_time": "2018-12-21T11:20:27.241126Z",
77-
"start_time": "2018-12-21T11:20:27.230640Z"
76+
"end_time": "2018-12-21T12:24:00.162508Z",
77+
"start_time": "2018-12-21T12:24:00.151792Z"
7878
}
7979
},
8080
"outputs": [],
8181
"source": [
82-
"env = lh.LaserHockeyEnvContinuous()"
82+
"env = lh.LaserHockeyEnv()"
8383
]
8484
},
8585
{
@@ -91,11 +91,11 @@
9191
},
9292
{
9393
"cell_type": "code",
94-
"execution_count": 20,
94+
"execution_count": 27,
9595
"metadata": {
9696
"ExecuteTime": {
97-
"end_time": "2018-12-21T11:22:12.232033Z",
98-
"start_time": "2018-12-21T11:22:12.213256Z"
97+
"end_time": "2018-12-21T12:24:01.084415Z",
98+
"start_time": "2018-12-21T12:24:00.981942Z"
9999
}
100100
},
101101
"outputs": [
@@ -105,7 +105,7 @@
105105
"True"
106106
]
107107
},
108-
"execution_count": 20,
108+
"execution_count": 27,
109109
"metadata": {},
110110
"output_type": "execute_result"
111111
}
@@ -216,7 +216,7 @@
216216
},
217217
"outputs": [],
218218
"source": [
219-
"env = lh.LaserHockeyEnvContinuous(mode=lh.LaserHockeyEnv.TRAIN_SHOOTING)"
219+
"env = lh.LaserHockeyEnv(mode=lh.LaserHockeyEnv.TRAIN_SHOOTING)"
220220
]
221221
},
222222
{
@@ -326,7 +326,7 @@
326326
},
327327
"outputs": [],
328328
"source": [
329-
"env = lh.LaserHockeyEnvContinuous(mode=lh.LaserHockeyEnv.TRAIN_DEFENCE)"
329+
"env = lh.LaserHockeyEnv(mode=lh.LaserHockeyEnv.TRAIN_DEFENCE)"
330330
]
331331
},
332332
{
@@ -387,6 +387,107 @@
387387
" obs_agent2 = env.obs_agent_two()\n",
388388
" if d: break"
389389
]
390+
},
391+
{
392+
"cell_type": "markdown",
393+
"metadata": {
394+
"ExecuteTime": {
395+
"end_time": "2018-12-20T20:37:41.013424Z",
396+
"start_time": "2018-12-20T20:37:41.009298Z"
397+
}
398+
},
399+
"source": [
400+
"# Using discrete actions"
401+
]
402+
},
403+
{
404+
"cell_type": "code",
405+
"execution_count": 28,
406+
"metadata": {
407+
"ExecuteTime": {
408+
"end_time": "2018-12-21T12:24:05.082438Z",
409+
"start_time": "2018-12-21T12:24:05.072962Z"
410+
}
411+
},
412+
"outputs": [
413+
{
414+
"data": {
415+
"text/plain": [
416+
"<module 'laser_hockey_env' from '/home/georg/src/python/laser-hockey-env/laser_hockey_env.py'>"
417+
]
418+
},
419+
"execution_count": 28,
420+
"metadata": {},
421+
"output_type": "execute_result"
422+
}
423+
],
424+
"source": [
425+
"reload(lh)"
426+
]
427+
},
428+
{
429+
"cell_type": "code",
430+
"execution_count": 29,
431+
"metadata": {
432+
"ExecuteTime": {
433+
"end_time": "2018-12-21T12:24:05.821293Z",
434+
"start_time": "2018-12-21T12:24:05.814344Z"
435+
}
436+
},
437+
"outputs": [],
438+
"source": [
439+
"env = lh.LaserHockeyEnv(mode=lh.LaserHockeyEnv.TRAIN_SHOOTING)"
440+
]
441+
},
442+
{
443+
"cell_type": "code",
444+
"execution_count": 32,
445+
"metadata": {
446+
"ExecuteTime": {
447+
"end_time": "2018-12-21T12:24:56.153283Z",
448+
"start_time": "2018-12-21T12:24:56.148110Z"
449+
}
450+
},
451+
"outputs": [],
452+
"source": [
453+
"import random"
454+
]
455+
},
456+
{
457+
"cell_type": "code",
458+
"execution_count": 35,
459+
"metadata": {
460+
"ExecuteTime": {
461+
"end_time": "2018-12-21T12:25:42.213473Z",
462+
"start_time": "2018-12-21T12:25:39.046109Z"
463+
}
464+
},
465+
"outputs": [
466+
{
467+
"name": "stdout",
468+
"output_type": "stream",
469+
"text": [
470+
"Player 2 scored\n"
471+
]
472+
}
473+
],
474+
"source": [
475+
"for _ in range(200):\n",
476+
" env.render()\n",
477+
" a1_discrete = random.randint(0,7)\n",
478+
" a1 = env.discrete_to_continous_action(a1_discrete)\n",
479+
" a2 = [0,0.,0] \n",
480+
" obs, r, d, info = env.step(np.hstack([a1,a2])) \n",
481+
" obs_agent2 = env.obs_agent_two()\n",
482+
" if d: break"
483+
]
484+
},
485+
{
486+
"cell_type": "code",
487+
"execution_count": null,
488+
"metadata": {},
489+
"outputs": [],
490+
"source": []
390491
}
391492
],
392493
"metadata": {

laser_hockey_env.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,8 @@ def __init__(self, mode = NORMAL):
103103
# y vel puck
104104
self.observation_space = spaces.Box(-np.inf, np.inf, shape=(16,), dtype=np.float32)
105105

106-
if self.continuous:
107-
# linear force in (x,y)-direction and torque
108-
self.action_space = spaces.Box(-1, +1, (3*2,), dtype=np.float32)
109-
else:
110-
111-
self.action_space = spaces.Discrete(6*2)
106+
# linear force in (x,y)-direction and torque
107+
self.action_space = spaces.Box(-1, +1, (3*2,), dtype=np.float32)
112108

113109
self.reset()
114110

@@ -445,17 +441,33 @@ def _get_info(self):
445441
winner=self.winner
446442
)
447443

444+
445+
def discrete_to_continous_action(self, discrete_action):
446+
''' converts discrete actions into continuous ones (for each player)
447+
The actions allow only one operation per timestep, e.g. a change in X, in Y, or in angle.
448+
This is surely limiting; other discrete action sets are possible.
449+
Action 0: do nothing
450+
Action 1: -1 in x
451+
Action 2: 1 in x
452+
Action 3: -1 in y
453+
Action 4: 1 in y
454+
Action 5: -1 in angle
455+
Action 6: 1 in angle
456+
'''
457+
action_cont = [(discrete_action==1) * -1 + (discrete_action==2) * 1, # player x
458+
(discrete_action==3) * -1 + (discrete_action==4) * 1, # player y
459+
(discrete_action==5) * -1 + (discrete_action==6) * 1] # player angle
460+
461+
return action_cont
462+
463+
448464
def step(self, action):
449-
if self.continuous:
450-
action = np.clip(action, -1, +1).astype(np.float32)
451-
else:
452-
assert self.action_space.contains(action), "%r (%s) invalid " % (action, type(action))
453-
pass
465+
action = np.clip(action, -1, +1).astype(np.float32)
454466

455467
self._apply_action_with_max_speed(self.player1, action[:2], 10, True)
456468
self.player1.ApplyTorque(action[2] * TORQUEMULTIPLAYER, True)
457469
self._apply_action_with_max_speed(self.player2, action[3:5], 10, False)
458-
self.player2.ApplyTorque(action[5] * TORQUEMULTIPLAYER, True)
470+
self.player2.ApplyTorque(-action[5] * TORQUEMULTIPLAYER, True)
459471

460472
self.world.Step(self.timeStep, 6 * 30, 2 * 30)
461473

@@ -502,7 +514,3 @@ def close(self):
502514
if self.viewer is not None:
503515
self.viewer.close()
504516
self.viewer = None
505-
506-
507-
class LaserHockeyEnvContinuous(LaserHockeyEnv):
508-
continuous = True

0 commit comments

Comments
 (0)