Skip to content

Commit 0b167ed

Browse files
author
Joseph Suarez
committed
Merge branch 'iw_stuff' of https://github.com/capnspacehook/PufferLib into capnspacehook-iw-stuff
2 parents bcb2fc5 + 705c12f commit 0b167ed

File tree

9 files changed

+541
-188
lines changed

9 files changed

+541
-188
lines changed

pufferlib/config/ocean/impulse_wars.ini

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ continuous = False
3030
is_training = True
3131

3232
[train]
33-
total_timesteps = 100_000_000
33+
total_timesteps = 1_000_000_000
3434
checkpoint_interval = 250
3535

3636
learning_rate = 0.005
@@ -47,6 +47,78 @@ max = 512
4747
mean = 128
4848
scale = auto
4949

50+
# reward parameters
51+
[sweep.env.reward_win]
52+
distribution = uniform
53+
min = 0.0
54+
mean = 2.0
55+
max = 5.0
56+
scale = auto
57+
58+
[sweep.env.reward_self_kill]
59+
distribution = uniform
60+
min = -3.0
61+
mean = -1.0
62+
max = 0.0
63+
scale = auto
64+
65+
[sweep.env.reward_enemy_death]
66+
distribution = uniform
67+
min = 0.0
68+
mean = 1.0
69+
max = 3.0
70+
scale = auto
71+
72+
[sweep.env.reward_kill]
73+
distribution = uniform
74+
min = 0.0
75+
mean = 1.0
76+
max = 3.0
77+
scale = auto
78+
79+
[sweep.env.reward_death]
80+
distribution = uniform
81+
min = -1.0
82+
mean = -0.25
83+
max = 0.0
84+
scale = auto
85+
86+
[sweep.env.reward_energy_emptied]
87+
distribution = uniform
88+
min = -2.0
89+
mean = -0.75
90+
max = 0.0
91+
scale = auto
92+
93+
[sweep.env.reward_weapon_pickup]
94+
distribution = uniform
95+
min = 0.0
96+
mean = 0.5
97+
max = 3.0
98+
scale = auto
99+
100+
[sweep.env.reward_shield_break]
101+
distribution = uniform
102+
min = 0.0
103+
mean = 0.5
104+
max = 3.0
105+
scale = auto
106+
107+
[sweep.env.reward_shot_hit_coef]
108+
distribution = log_normal
109+
min = 0.0005
110+
mean = 0.005
111+
max = 0.05
112+
scale = auto
113+
114+
[sweep.env.reward_explosion_hit_coef]
115+
distribution = log_normal
116+
min = 0.0005
117+
mean = 0.005
118+
max = 0.05
119+
scale = auto
120+
121+
# hyperparameters
50122
[sweep.train.total_timesteps]
51123
distribution = log_normal
52124
min = 250_000_000

pufferlib/ocean/impulse_wars/binding.c

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,21 @@ static int my_init(iwEnv *e, PyObject *args, PyObject *kwargs) {
100100
(bool)unpack(kwargs, "is_training"),
101101
(bool)unpack(kwargs, "continuous")
102102
);
103+
setRewards(
104+
e,
105+
(float)unpack(kwargs, "reward_win"),
106+
(float)unpack(kwargs, "reward_self_kill"),
107+
(float)unpack(kwargs, "reward_enemy_death"),
108+
(float)unpack(kwargs, "reward_enemy_kill"),
109+
0.0f, // teammate death punishment
110+
0.0f, // teammate kill punishment
111+
(float)unpack(kwargs, "reward_death"),
112+
(float)unpack(kwargs, "reward_energy_emptied"),
113+
(float)unpack(kwargs, "reward_weapon_pickup"),
114+
(float)unpack(kwargs, "reward_shield_break"),
115+
(float)unpack(kwargs, "reward_shot_hit_coef"),
116+
(float)unpack(kwargs, "reward_explosion_hit_coef")
117+
);
103118
return 0;
104119
}
105120

@@ -131,6 +146,11 @@ static int my_log(PyObject *dict, Log *log) {
131146
assign_to_dict(dict, droneLog(buf, i, "total_bursts"), log->stats[i].totalBursts);
132147
assign_to_dict(dict, droneLog(buf, i, "bursts_hit"), log->stats[i].burstsHit);
133148
assign_to_dict(dict, droneLog(buf, i, "energy_emptied"), log->stats[i].energyEmptied);
149+
assign_to_dict(dict, droneLog(buf, i, "shields_broken"), log->stats[i].shieldsBroken);
150+
assign_to_dict(dict, droneLog(buf, i, "own_shield_broken"), log->stats[i].ownShieldBroken);
151+
assign_to_dict(dict, droneLog(buf, i, "self_kills"), log->stats[i].selfKills);
152+
assign_to_dict(dict, droneLog(buf, i, "kills"), log->stats[i].kills);
153+
assign_to_dict(dict, droneLog(buf, i, "unknown_kills"), log->stats[i].unknownKills);
134154
assign_to_dict(dict, droneLog(buf, i, "wins"), log->stats[i].wins);
135155

136156
// useful for debugging weapon balance, but really slows down

pufferlib/ocean/impulse_wars/env.h

Lines changed: 86 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -367,10 +367,10 @@ void computeObs(iwEnv *e) {
367367
continue;
368368
}
369369

370-
if (agentDrone->stepInfo.shotHit[i]) {
370+
if (agentDrone->stepInfo.shotHit[i] != 0.0f) {
371371
hitShot = true;
372372
}
373-
if (agentDrone->stepInfo.shotTaken[i]) {
373+
if (agentDrone->stepInfo.shotTaken[i] != 0.0f) {
374374
tookShot = true;
375375
}
376376

@@ -539,6 +539,19 @@ iwEnv *initEnv(iwEnv *e, uint8_t numDrones, uint8_t numAgents, int8_t mapIdx, ui
539539
e->sittingDuck = sittingDuck;
540540
e->isTraining = isTraining;
541541

542+
e->winReward = WIN_REWARD;
543+
e->selfKillPunishment = SELF_KILL_PUNISHMENT;
544+
e->enemyDeathReward = ENEMY_DEATH_REWARD;
545+
e->enemyKillReward = ENEMY_KILL_REWARD;
546+
e->teammateDeathPunishment = TEAMMATE_DEATH_PUNISHMENT;
547+
e->teammateKillPunishment = TEAMMATE_KILL_PUNISHMENT;
548+
e->deathPunishment = DEATH_PUNISHMENT;
549+
e->energyEmptiedPunishment = ENERGY_EMPTY_PUNISHMENT;
550+
e->weaponPickupReward = WEAPON_PICKUP_REWARD;
551+
e->shieldBreakReward = SHIELD_BREAK_REWARD;
552+
e->shotHitRewardCoef = SHOT_HIT_REWARD_COEF;
553+
e->explosionHitRewardCoef = EXPLOSION_HIT_REWARD_COEF;
554+
542555
e->obsBytes = obsBytes(e->numDrones);
543556
e->discreteObsBytes = alignedSize(discreteObsSize(e->numDrones) * sizeof(uint8_t), sizeof(float));
544557

@@ -583,9 +596,28 @@ iwEnv *initEnv(iwEnv *e, uint8_t numDrones, uint8_t numAgents, int8_t mapIdx, ui
583596
e->humanDroneInput = 0;
584597
e->connectedControllers = 0;
585598

599+
#ifndef NDEBUG
600+
create_array(&e->debugPoints, 4);
601+
#endif
602+
586603
return e;
587604
}
588605

606+
void setRewards(iwEnv *e, float winReward, float selfKillPunishment, float enemyDeathReward, float enemyKillReward, float teammateDeathPunishment, float teammateKillPunishment, float deathPunishment, float energyEmptiedPunishment, float weaponPickupReward, float shieldBreakReward, float shotHitRewardCoef, float explosionHitRewardCoef) {
607+
e->winReward = winReward;
608+
e->selfKillPunishment = selfKillPunishment;
609+
e->enemyDeathReward = enemyDeathReward;
610+
e->enemyKillReward = enemyKillReward;
611+
e->teammateDeathPunishment = teammateDeathPunishment;
612+
e->teammateKillPunishment = teammateKillPunishment;
613+
e->deathPunishment = deathPunishment;
614+
e->energyEmptiedPunishment = energyEmptiedPunishment;
615+
e->weaponPickupReward = weaponPickupReward;
616+
e->shieldBreakReward = shieldBreakReward;
617+
e->shotHitRewardCoef = shotHitRewardCoef;
618+
e->explosionHitRewardCoef = explosionHitRewardCoef;
619+
}
620+
589621
void clearEnv(iwEnv *e) {
590622
// rewards get cleared in stepEnv every step
591623
// memset(e->masks, 1, e->numAgents * sizeof(uint8_t));
@@ -673,36 +705,30 @@ void destroyEnv(iwEnv *e) {
673705
cc_array_destroy(e->dronePieces);
674706

675707
b2DestroyWorld(e->worldID);
708+
709+
#ifndef NDEBUG
710+
cc_array_destroy(e->debugPoints);
711+
#endif
676712
}
677713

678714
void resetEnv(iwEnv *e) {
679715
clearEnv(e);
680716
setupEnv(e);
681717
}
682718

683-
float computeShotReward(const droneEntity *drone, const weaponInformation *weaponInfo) {
684-
const float weaponForce = weaponInfo->fireMagnitude * weaponInfo->invMass;
685-
const float scaledForce = (weaponForce * (weaponForce * SHOT_HIT_REWARD_COEF)) + 0.25f;
686-
return scaledForce + computeHitStrength(drone);
687-
}
688-
689-
float computeExplosionReward(const droneEntity *drone) {
690-
return computeHitStrength(drone) * EXPLOSION_HIT_REWARD_COEF;
691-
}
692-
693719
float computeReward(iwEnv *e, droneEntity *drone) {
694720
float reward = 0.0f;
695721

696722
if (drone->energyFullyDepleted && drone->energyRefillWait == DRONE_ENERGY_REFILL_EMPTY_WAIT) {
697-
reward += ENERGY_EMPTY_PUNISHMENT;
723+
reward += e->energyEmptiedPunishment;
698724
}
699725

700726
// only reward picking up a weapon if the standard weapon was
701727
// previously held; every weapon is better than the standard
702728
// weapon, but other weapons are situationally better so don't
703729
// reward switching a non-standard weapon
704730
if (drone->stepInfo.pickedUpWeapon && drone->stepInfo.prevWeapon == STANDARD_WEAPON) {
705-
reward += WEAPON_PICKUP_REWARD;
731+
reward += e->weaponPickupReward;
706732
}
707733

708734
for (uint8_t i = 0; i < e->numDrones; i++) {
@@ -712,51 +738,51 @@ float computeReward(iwEnv *e, droneEntity *drone) {
712738
droneEntity *enemyDrone = safe_array_get_at(e->drones, i);
713739
const bool onTeam = drone->team == enemyDrone->team;
714740

715-
if (drone->stepInfo.shotHit[i] != 0 && !onTeam) {
716-
// subtract 1 from the weapon type because 1 is added so we
717-
// can use 0 as no shot was hit
718-
const weaponInformation *weaponInfo = weaponInfos[drone->stepInfo.shotHit[i] - 1];
719-
reward += computeShotReward(enemyDrone, weaponInfo);
741+
// TODO: punish for hitting teammates?
742+
if (drone->stepInfo.shotHit[i] != 0.0f && !onTeam) {
743+
reward += drone->stepInfo.shotHit[i] * e->shotHitRewardCoef;
720744
}
721-
if (drone->stepInfo.explosionHit[i] && !onTeam) {
722-
reward += computeExplosionReward(enemyDrone);
745+
if (drone->stepInfo.explosionHit[i] != 0.0f && !onTeam) {
746+
reward += drone->stepInfo.explosionHit[i] * e->explosionHitRewardCoef;
747+
}
748+
if (drone->stepInfo.brokeShield[i] && !onTeam) {
749+
reward += e->shieldBreakReward;
723750
}
724751

725752
if (e->numAgents == e->numDrones) {
726-
if (drone->stepInfo.shotTaken[i] != 0 && !onTeam) {
727-
const weaponInformation *weaponInfo = weaponInfos[drone->stepInfo.shotTaken[i] - 1];
728-
reward -= computeShotReward(drone, weaponInfo) * 0.5f;
753+
if (drone->stepInfo.shotTaken[i] != 0) {
754+
reward -= drone->stepInfo.shotTaken[i] * e->shotHitRewardCoef;
729755
}
730-
if (drone->stepInfo.explosionTaken[i] && !onTeam) {
731-
reward -= computeExplosionReward(drone) * 0.5f;
756+
if (drone->stepInfo.explosionTaken[i]) {
757+
reward -= drone->stepInfo.explosionTaken[i] * e->explosionHitRewardCoef;
732758
}
733759
}
734760

735761
if (enemyDrone->dead && enemyDrone->diedThisStep) {
736762
if (!onTeam) {
737-
reward += ENEMY_DEATH_REWARD;
763+
reward += e->enemyDeathReward;
738764
if (drone->killed[i]) {
739-
reward += ENEMY_KILL_REWARD;
765+
reward += e->enemyKillReward;
740766
}
741767
} else {
742-
reward += TEAMMATE_DEATH_PUNISHMENT;
768+
reward += e->teammateDeathPunishment;
743769
if (drone->killed[i]) {
744-
reward += TEAMMATE_KILL_PUNISHMENT;
770+
reward += e->teammateKillPunishment;
745771
}
746772
}
747773
continue;
748774
}
749775

750-
const b2Vec2 enemyDirection = b2Normalize(b2Sub(enemyDrone->pos, drone->pos));
751-
const float velocityToEnemy = b2Dot(drone->lastVelocity, enemyDirection);
752-
const float enemyDistance = b2Distance(enemyDrone->pos, drone->pos);
753-
// stop rewarding approaching an enemy if they're very close
754-
// to avoid constant clashing; always reward approaching when
755-
// the current weapon is the shotgun, it greatly benefits from
756-
// being close to enemies
757-
if (velocityToEnemy > 0.1f && (drone->weaponInfo->type == SHOTGUN_WEAPON || enemyDistance > DISTANCE_CUTOFF)) {
758-
reward += APPROACH_REWARD;
759-
}
776+
// const b2Vec2 enemyDirection = b2Normalize(b2Sub(enemyDrone->pos, drone->pos));
777+
// const float velocityToEnemy = b2Dot(drone->lastVelocity, enemyDirection);
778+
// const float enemyDistance = b2Distance(enemyDrone->pos, drone->pos);
779+
// // stop rewarding approaching an enemy if they're very close
780+
// // to avoid constant clashing; always reward approaching when
781+
// // the current weapon is the shotgun, it greatly benefits from
782+
// // being close to enemies
783+
// if (velocityToEnemy > 0.1f && (drone->weaponInfo->type == SHOTGUN_WEAPON || enemyDistance > DISTANCE_CUTOFF)) {
784+
// reward += APPROACH_REWARD;
785+
// }
760786
}
761787

762788
return reward;
@@ -766,21 +792,19 @@ const float REWARD_EPS = 1.0e-6f;
766792

767793
void computeRewards(iwEnv *e, const bool roundOver, const int8_t winner, const int8_t winningTeam) {
768794
if (roundOver && winner != -1 && winner < e->numAgents) {
769-
e->rewards[winner] += WIN_REWARD;
795+
e->rewards[winner] += e->winReward;
770796
}
771797

772798
for (uint8_t i = 0; i < e->numDrones; i++) {
773799
float reward = 0.0f;
774800
droneEntity *drone = safe_array_get_at(e->drones, i);
775-
if (!drone->dead) {
776-
reward = computeReward(e, drone);
777-
if (roundOver && winningTeam == drone->team) {
778-
reward += WIN_REWARD;
779-
}
801+
reward = computeReward(e, drone);
802+
if (!drone->dead && roundOver && winningTeam == drone->team) {
803+
reward += e->winReward;
780804
} else if (drone->diedThisStep) {
781-
reward = DEATH_PUNISHMENT;
805+
reward = e->deathPunishment;
782806
if (drone->killedBy == drone->idx) {
783-
reward += SELF_KILL_PUNISHMENT;
807+
reward += e->selfKillPunishment;
784808
}
785809
}
786810
if (i < e->numAgents) {
@@ -985,6 +1009,11 @@ void addLog(iwEnv *e, Log *log) {
9851009
e->log.stats[j].totalBursts += log->stats[j].totalBursts;
9861010
e->log.stats[j].burstsHit += log->stats[j].burstsHit;
9871011
e->log.stats[j].energyEmptied += log->stats[j].energyEmptied;
1012+
e->log.stats[j].shieldsBroken += log->stats[j].shieldsBroken;
1013+
e->log.stats[j].ownShieldBroken += log->stats[j].ownShieldBroken;
1014+
e->log.stats[j].selfKills += log->stats[j].selfKills;
1015+
e->log.stats[j].kills += log->stats[j].kills;
1016+
e->log.stats[j].unknownKills += log->stats[j].unknownKills;
9881017

9891018
for (uint8_t k = 0; k < NUM_WEAPONS; k++) {
9901019
e->log.stats[j].shotsFired[k] += log->stats[j].shotsFired[k];
@@ -1006,6 +1035,7 @@ void addLog(iwEnv *e, Log *log) {
10061035
e->log.n += 1.0f;
10071036
}
10081037

1038+
// TODO: 2nd agent doesn't seem to work right
10091039
void stepEnv(iwEnv *e) {
10101040
if (e->needsReset) {
10111041
DEBUG_LOG("Resetting environment");
@@ -1017,6 +1047,14 @@ void stepEnv(iwEnv *e) {
10171047
#endif
10181048
}
10191049

1050+
#ifndef NDEBUG
1051+
for (uint8_t i = 0; i < cc_array_size(e->debugPoints); i++) {
1052+
debugPoint *point = safe_array_get_at(e->debugPoints, i);
1053+
fastFree(point);
1054+
}
1055+
cc_array_remove_all(e->debugPoints);
1056+
#endif
1057+
10201058
agentActions stepActions[e->numDrones];
10211059
memset(stepActions, 0x0, e->numDrones * sizeof(agentActions));
10221060

0 commit comments

Comments
 (0)