Skip to content

Commit 84b7d89

Browse files
committed
make rewards configurable
1 parent ca42086 commit 84b7d89

File tree

5 files changed

+170
-16
lines changed

5 files changed

+170
-16
lines changed

pufferlib/config/ocean/impulse_wars.ini

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ continuous = False
3030
is_training = True
3131

3232
[train]
33-
total_timesteps = 100_000_000
33+
total_timesteps = 1_000_000_000
3434
checkpoint_interval = 250
3535

3636
learning_rate = 0.005
@@ -47,6 +47,78 @@ max = 512
4747
mean = 128
4848
scale = auto
4949

50+
# reward parameters
51+
[sweep.env.reward_win]
52+
distribution = uniform
53+
min = 0.0
54+
mean = 2.0
55+
max = 5.0
56+
scale = auto
57+
58+
[sweep.env.reward_self_kill]
59+
distribution = uniform
60+
min = -3.0
61+
mean = -1.0
62+
max = 0.0
63+
scale = auto
64+
65+
[sweep.env.reward_enemy_death]
66+
distribution = uniform
67+
min = 0.0
68+
mean = 1.0
69+
max = 3.0
70+
scale = auto
71+
72+
[sweep.env.reward_kill]
73+
distribution = uniform
74+
min = 0.0
75+
mean = 1.0
76+
max = 3.0
77+
scale = auto
78+
79+
[sweep.env.reward_death]
80+
distribution = uniform
81+
min = -1.0
82+
mean = -0.25
83+
max = 0.0
84+
scale = auto
85+
86+
[sweep.env.reward_energy_emptied]
87+
distribution = uniform
88+
min = -2.0
89+
mean = -0.75
90+
max = 0.0
91+
scale = auto
92+
93+
[sweep.env.reward_weapon_pickup]
94+
distribution = uniform
95+
min = 0.0
96+
mean = 0.5
97+
max = 3.0
98+
scale = auto
99+
100+
[sweep.env.reward_shield_break]
101+
distribution = uniform
102+
min = 0.0
103+
mean = 0.5
104+
max = 3.0
105+
scale = auto
106+
107+
[sweep.env.reward_shot_hit_coef]
108+
distribution = log_normal
109+
min = 0.0005
110+
mean = 0.005
111+
max = 0.05
112+
scale = auto
113+
114+
[sweep.env.reward_explosion_hit_coef]
115+
distribution = log_normal
116+
min = 0.0005
117+
mean = 0.005
118+
max = 0.05
119+
scale = auto
120+
121+
# hyperparameters
50122
[sweep.train.total_timesteps]
51123
distribution = log_normal
52124
min = 250_000_000

pufferlib/ocean/impulse_wars/binding.c

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,21 @@ static int my_init(iwEnv *e, PyObject *args, PyObject *kwargs) {
100100
(bool)unpack(kwargs, "is_training"),
101101
(bool)unpack(kwargs, "continuous")
102102
);
103+
setRewards(
104+
e,
105+
(float)unpack(kwargs, "reward_win"),
106+
(float)unpack(kwargs, "reward_self_kill"),
107+
(float)unpack(kwargs, "reward_enemy_death"),
108+
(float)unpack(kwargs, "reward_enemy_kill"),
109+
0.0f, // teammate death punishment
110+
0.0f, // teammate kill punishment
111+
(float)unpack(kwargs, "reward_death"),
112+
(float)unpack(kwargs, "reward_energy_emptied"),
113+
(float)unpack(kwargs, "reward_weapon_pickup"),
114+
(float)unpack(kwargs, "reward_shield_break"),
115+
(float)unpack(kwargs, "reward_shot_hit_coef"),
116+
(float)unpack(kwargs, "reward_explosion_hit_coef")
117+
);
103118
return 0;
104119
}
105120

@@ -131,6 +146,8 @@ static int my_log(PyObject *dict, Log *log) {
131146
assign_to_dict(dict, droneLog(buf, i, "total_bursts"), log->stats[i].totalBursts);
132147
assign_to_dict(dict, droneLog(buf, i, "bursts_hit"), log->stats[i].burstsHit);
133148
assign_to_dict(dict, droneLog(buf, i, "energy_emptied"), log->stats[i].energyEmptied);
149+
assign_to_dict(dict, droneLog(buf, i, "shields_broken"), log->stats[i].shieldsBroken);
150+
assign_to_dict(dict, droneLog(buf, i, "own_shield_broken"), log->stats[i].ownShieldBroken);
134151
assign_to_dict(dict, droneLog(buf, i, "self_kills"), log->stats[i].selfKills);
135152
assign_to_dict(dict, droneLog(buf, i, "kills"), log->stats[i].kills);
136153
assign_to_dict(dict, droneLog(buf, i, "wins"), log->stats[i].wins);

pufferlib/ocean/impulse_wars/env.h

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,19 @@ iwEnv *initEnv(iwEnv *e, uint8_t numDrones, uint8_t numAgents, int8_t mapIdx, ui
539539
e->sittingDuck = sittingDuck;
540540
e->isTraining = isTraining;
541541

542+
e->winReward = WIN_REWARD;
543+
e->selfKillPunishment = SELF_KILL_PUNISHMENT;
544+
e->enemyDeathReward = ENEMY_DEATH_REWARD;
545+
e->enemyKillReward = ENEMY_KILL_REWARD;
546+
e->teammateDeathPunishment = TEAMMATE_DEATH_PUNISHMENT;
547+
e->teammateKillPunishment = TEAMMATE_KILL_PUNISHMENT;
548+
e->deathPunishment = DEATH_PUNISHMENT;
549+
e->energyEmptiedPunishment = ENERGY_EMPTY_PUNISHMENT;
550+
e->weaponPickupReward = WEAPON_PICKUP_REWARD;
551+
e->shieldBreakReward = SHIELD_BREAK_REWARD;
552+
e->shotHitRewardCoef = SHOT_HIT_REWARD_COEF;
553+
e->explosionHitRewardCoef = EXPLOSION_HIT_REWARD_COEF;
554+
542555
e->obsBytes = obsBytes(e->numDrones);
543556
e->discreteObsBytes = alignedSize(discreteObsSize(e->numDrones) * sizeof(uint8_t), sizeof(float));
544557

@@ -586,6 +599,21 @@ iwEnv *initEnv(iwEnv *e, uint8_t numDrones, uint8_t numAgents, int8_t mapIdx, ui
586599
return e;
587600
}
588601

602+
void setRewards(iwEnv *e, float winReward, float selfKillPunishment, float enemyDeathReward, float enemyKillReward, float teammateDeathPunishment, float teammateKillPunishment, float deathPunishment, float energyEmptiedPunishment, float weaponPickupReward, float shieldBreakReward, float shotHitRewardCoef, float explosionHitRewardCoef) {
603+
e->winReward = winReward;
604+
e->selfKillPunishment = selfKillPunishment;
605+
e->enemyDeathReward = enemyDeathReward;
606+
e->enemyKillReward = enemyKillReward;
607+
e->teammateDeathPunishment = teammateDeathPunishment;
608+
e->teammateKillPunishment = teammateKillPunishment;
609+
e->deathPunishment = deathPunishment;
610+
e->energyEmptiedPunishment = energyEmptiedPunishment;
611+
e->weaponPickupReward = weaponPickupReward;
612+
e->shieldBreakReward = shieldBreakReward;
613+
e->shotHitRewardCoef = shotHitRewardCoef;
614+
e->explosionHitRewardCoef = explosionHitRewardCoef;
615+
}
616+
589617
void clearEnv(iwEnv *e) {
590618
// rewards get cleared in stepEnv every step
591619
// memset(e->masks, 1, e->numAgents * sizeof(uint8_t));
@@ -684,15 +712,15 @@ float computeReward(iwEnv *e, droneEntity *drone) {
684712
float reward = 0.0f;
685713

686714
if (drone->energyFullyDepleted && drone->energyRefillWait == DRONE_ENERGY_REFILL_EMPTY_WAIT) {
687-
reward += ENERGY_EMPTY_PUNISHMENT;
715+
reward += e->energyEmptiedPunishment;
688716
}
689717

690718
// only reward picking up a weapon if the standard weapon was
691719
// previously held; every weapon is better than the standard
692720
// weapon, but other weapons are situational better so don't
693721
// reward switching a non-standard weapon
694722
if (drone->stepInfo.pickedUpWeapon && drone->stepInfo.prevWeapon == STANDARD_WEAPON) {
695-
reward += WEAPON_PICKUP_REWARD;
723+
reward += e->weaponPickupReward;
696724
}
697725

698726
for (uint8_t i = 0; i < e->numDrones; i++) {
@@ -704,34 +732,34 @@ float computeReward(iwEnv *e, droneEntity *drone) {
704732

705733
// TODO: punish for hitting teammates?
706734
if (drone->stepInfo.shotHit[i] != 0.0f && !onTeam) {
707-
reward += drone->stepInfo.shotHit[i] * SHOT_HIT_REWARD_COEF;
735+
reward += drone->stepInfo.shotHit[i] * e->shotHitRewardCoef;
708736
}
709737
if (drone->stepInfo.explosionHit[i] != 0.0f && !onTeam) {
710-
reward += drone->stepInfo.explosionHit[i] * EXPLOSION_HIT_REWARD_COEF;
738+
reward += drone->stepInfo.explosionHit[i] * e->explosionHitRewardCoef;
711739
}
712740
if (drone->stepInfo.brokeShield[i] && !onTeam) {
713-
reward += SHIELD_BREAK_REWARD;
741+
reward += e->shieldBreakReward;
714742
}
715743

716744
if (e->numAgents == e->numDrones) {
717745
if (drone->stepInfo.shotTaken[i] != 0) {
718-
reward -= drone->stepInfo.shotTaken[i] * SHOT_HIT_REWARD_COEF;
746+
reward -= drone->stepInfo.shotTaken[i] * e->shotHitRewardCoef;
719747
}
720748
if (drone->stepInfo.explosionTaken[i]) {
721-
reward -= drone->stepInfo.explosionTaken[i] * EXPLOSION_HIT_REWARD_COEF;
749+
reward -= drone->stepInfo.explosionTaken[i] * e->explosionHitRewardCoef;
722750
}
723751
}
724752

725753
if (enemyDrone->dead && enemyDrone->diedThisStep) {
726754
if (!onTeam) {
727-
reward += ENEMY_DEATH_REWARD;
755+
reward += e->enemyDeathReward;
728756
if (drone->killed[i]) {
729-
reward += ENEMY_KILL_REWARD;
757+
reward += e->enemyKillReward;
730758
}
731759
} else {
732-
reward += TEAMMATE_DEATH_PUNISHMENT;
760+
reward += e->teammateDeathPunishment;
733761
if (drone->killed[i]) {
734-
reward += TEAMMATE_KILL_PUNISHMENT;
762+
reward += e->teammateKillPunishment;
735763
}
736764
}
737765
continue;
@@ -756,19 +784,19 @@ const float REWARD_EPS = 1.0e-6f;
756784

757785
void computeRewards(iwEnv *e, const bool roundOver, const int8_t winner, const int8_t winningTeam) {
758786
if (roundOver && winner != -1 && winner < e->numAgents) {
759-
e->rewards[winner] += WIN_REWARD;
787+
e->rewards[winner] += e->winReward;
760788
}
761789

762790
for (uint8_t i = 0; i < e->numDrones; i++) {
763791
float reward = 0.0f;
764792
droneEntity *drone = safe_array_get_at(e->drones, i);
765793
reward = computeReward(e, drone);
766794
if (!drone->dead && roundOver && winningTeam == drone->team) {
767-
reward += WIN_REWARD;
795+
reward += e->winReward;
768796
} else if (drone->diedThisStep) {
769-
reward = DEATH_PUNISHMENT;
797+
reward = e->deathPunishment;
770798
if (drone->killedBy == drone->idx) {
771-
reward += SELF_KILL_PUNISHMENT;
799+
reward += e->selfKillPunishment;
772800
}
773801
}
774802
if (i < e->numAgents) {
@@ -973,6 +1001,10 @@ void addLog(iwEnv *e, Log *log) {
9731001
e->log.stats[j].totalBursts += log->stats[j].totalBursts;
9741002
e->log.stats[j].burstsHit += log->stats[j].burstsHit;
9751003
e->log.stats[j].energyEmptied += log->stats[j].energyEmptied;
1004+
e->log.stats[j].shieldsBroken += log->stats[j].shieldsBroken;
1005+
e->log.stats[j].ownShieldBroken += log->stats[j].ownShieldBroken;
1006+
e->log.stats[j].selfKills += log->stats[j].selfKills;
1007+
e->log.stats[j].kills += log->stats[j].kills;
9761008

9771009
for (uint8_t k = 0; k < NUM_WEAPONS; k++) {
9781010
e->log.stats[j].shotsFired[k] += log->stats[j].shotsFired[k];

pufferlib/ocean/impulse_wars/impulse_wars.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,16 @@ def __init__(
2828
continuous: bool = False,
2929
is_training: bool = True,
3030
human_control: bool = False,
31+
reward_win: float = 2.0,
32+
reward_self_kill: float = -1.0,
33+
reward_enemy_death: float = 1.0,
34+
reward_enemy_kill: float = 1.0,
35+
reward_death: float = -0.25,
36+
reward_energy_emptied: float = -0.75,
37+
reward_weapon_pickup: float = 0.5,
38+
reward_shield_break: float = 0.5,
39+
reward_shot_hit_coef: float = 0.005,
40+
reward_explosion_hit_coef: float = 0.005,
3141
seed: int = 0,
3242
render: bool = False,
3343
report_interval: int = 64,
@@ -98,6 +108,16 @@ def __init__(
98108
sitting_duck=sitting_duck,
99109
is_training=is_training,
100110
continuous=continuous,
111+
reward_win=reward_win,
112+
reward_self_kill=reward_self_kill,
113+
reward_enemy_death=reward_enemy_death,
114+
reward_enemy_kill=reward_enemy_kill,
115+
reward_death=reward_death,
116+
reward_energy_emptied=reward_energy_emptied,
117+
reward_weapon_pickup=reward_weapon_pickup,
118+
reward_shield_break=reward_shield_break,
119+
reward_shot_hit_coef=reward_shot_hit_coef,
120+
reward_explosion_hit_coef=reward_explosion_hit_coef,
101121
)
102122

103123
binding.shared(self.c_envs)

pufferlib/ocean/impulse_wars/types.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,19 @@ typedef struct iwEnv {
412412
bool sittingDuck;
413413
bool isTraining;
414414

415+
float winReward;
416+
float selfKillPunishment;
417+
float enemyDeathReward;
418+
float enemyKillReward;
419+
float teammateDeathPunishment;
420+
float teammateKillPunishment;
421+
float deathPunishment;
422+
float energyEmptiedPunishment;
423+
float weaponPickupReward;
424+
float shieldBreakReward;
425+
float shotHitRewardCoef;
426+
float explosionHitRewardCoef;
427+
415428
uint16_t obsBytes;
416429
uint16_t discreteObsBytes;
417430
bool continuousActions;

0 commit comments

Comments
 (0)