@@ -539,6 +539,19 @@ iwEnv *initEnv(iwEnv *e, uint8_t numDrones, uint8_t numAgents, int8_t mapIdx, ui
539539 e -> sittingDuck = sittingDuck ;
540540 e -> isTraining = isTraining ;
541541
542+ e -> winReward = WIN_REWARD ;
543+ e -> selfKillPunishment = SELF_KILL_PUNISHMENT ;
544+ e -> enemyDeathReward = ENEMY_DEATH_REWARD ;
545+ e -> enemyKillReward = ENEMY_KILL_REWARD ;
546+ e -> teammateDeathPunishment = TEAMMATE_DEATH_PUNISHMENT ;
547+ e -> teammateKillPunishment = TEAMMATE_KILL_PUNISHMENT ;
548+ e -> deathPunishment = DEATH_PUNISHMENT ;
549+ e -> energyEmptiedPunishment = ENERGY_EMPTY_PUNISHMENT ;
550+ e -> weaponPickupReward = WEAPON_PICKUP_REWARD ;
551+ e -> shieldBreakReward = SHIELD_BREAK_REWARD ;
552+ e -> shotHitRewardCoef = SHOT_HIT_REWARD_COEF ;
553+ e -> explosionHitRewardCoef = EXPLOSION_HIT_REWARD_COEF ;
554+
542555 e -> obsBytes = obsBytes (e -> numDrones );
543556 e -> discreteObsBytes = alignedSize (discreteObsSize (e -> numDrones ) * sizeof (uint8_t ), sizeof (float ));
544557
@@ -586,6 +599,21 @@ iwEnv *initEnv(iwEnv *e, uint8_t numDrones, uint8_t numAgents, int8_t mapIdx, ui
586599 return e ;
587600}
588601
// setRewards overrides the per-environment reward settings stored on e.
//
// Each of the twelve float arguments is copied verbatim into the field of
// the same name on e; nothing else on the environment is modified and no
// value is validated or clamped. initEnv assigns these same fields from the
// compile-time constants (WIN_REWARD, SELF_KILL_PUNISHMENT, ...), so calling
// this after initEnv replaces those defaults.
//
// How the values are consumed by computeReward / computeRewards:
//   - the *Reward and *Punishment values are all applied by addition
//     (e.g. `reward += e->teammateDeathPunishment`), except deathPunishment,
//     which is assigned and replaces the drone's reward for that step;
//     punishments are presumably meant to be passed as negative numbers
//     -- TODO confirm against the constants' definitions.
//   - shotHitRewardCoef and explosionHitRewardCoef are multiplied by the
//     per-drone shotHit/explosionHit amounts and added, and multiplied by
//     shotTaken/explosionTaken and subtracted when numAgents == numDrones.
//
// e is dereferenced without a NULL check; the caller must pass a valid
// environment.
602+ void setRewards (iwEnv * e , float winReward , float selfKillPunishment , float enemyDeathReward , float enemyKillReward , float teammateDeathPunishment , float teammateKillPunishment , float deathPunishment , float energyEmptiedPunishment , float weaponPickupReward , float shieldBreakReward , float shotHitRewardCoef , float explosionHitRewardCoef ) {
603+ e -> winReward = winReward ;
604+ e -> selfKillPunishment = selfKillPunishment ;
605+ e -> enemyDeathReward = enemyDeathReward ;
606+ e -> enemyKillReward = enemyKillReward ;
607+ e -> teammateDeathPunishment = teammateDeathPunishment ;
608+ e -> teammateKillPunishment = teammateKillPunishment ;
609+ e -> deathPunishment = deathPunishment ;
610+ e -> energyEmptiedPunishment = energyEmptiedPunishment ;
611+ e -> weaponPickupReward = weaponPickupReward ;
612+ e -> shieldBreakReward = shieldBreakReward ;
613+ e -> shotHitRewardCoef = shotHitRewardCoef ;
614+ e -> explosionHitRewardCoef = explosionHitRewardCoef ;
615+ }
616+
589617void clearEnv (iwEnv * e ) {
590618 // rewards get cleared in stepEnv every step
591619 // memset(e->masks, 1, e->numAgents * sizeof(uint8_t));
@@ -684,15 +712,15 @@ float computeReward(iwEnv *e, droneEntity *drone) {
684712 float reward = 0.0f ;
685713
686714 if (drone -> energyFullyDepleted && drone -> energyRefillWait == DRONE_ENERGY_REFILL_EMPTY_WAIT ) {
687- reward += ENERGY_EMPTY_PUNISHMENT ;
715+ reward += e -> energyEmptiedPunishment ;
688716 }
689717
690718 // only reward picking up a weapon if the standard weapon was
691719 // previously held; every weapon is better than the standard
692720 // weapon, but other weapons are situationally better so don't
693721 // reward switching a non-standard weapon
694722 if (drone -> stepInfo .pickedUpWeapon && drone -> stepInfo .prevWeapon == STANDARD_WEAPON ) {
695- reward += WEAPON_PICKUP_REWARD ;
723+ reward += e -> weaponPickupReward ;
696724 }
697725
698726 for (uint8_t i = 0 ; i < e -> numDrones ; i ++ ) {
@@ -704,34 +732,34 @@ float computeReward(iwEnv *e, droneEntity *drone) {
704732
705733 // TODO: punish for hitting teammates?
706734 if (drone -> stepInfo .shotHit [i ] != 0.0f && !onTeam ) {
707- reward += drone -> stepInfo .shotHit [i ] * SHOT_HIT_REWARD_COEF ;
735+ reward += drone -> stepInfo .shotHit [i ] * e -> shotHitRewardCoef ;
708736 }
709737 if (drone -> stepInfo .explosionHit [i ] != 0.0f && !onTeam ) {
710- reward += drone -> stepInfo .explosionHit [i ] * EXPLOSION_HIT_REWARD_COEF ;
738+ reward += drone -> stepInfo .explosionHit [i ] * e -> explosionHitRewardCoef ;
711739 }
712740 if (drone -> stepInfo .brokeShield [i ] && !onTeam ) {
713- reward += SHIELD_BREAK_REWARD ;
741+ reward += e -> shieldBreakReward ;
714742 }
715743
716744 if (e -> numAgents == e -> numDrones ) {
717745 if (drone -> stepInfo .shotTaken [i ] != 0 ) {
718- reward -= drone -> stepInfo .shotTaken [i ] * SHOT_HIT_REWARD_COEF ;
746+ reward -= drone -> stepInfo .shotTaken [i ] * e -> shotHitRewardCoef ;
719747 }
720748 if (drone -> stepInfo .explosionTaken [i ]) {
721- reward -= drone -> stepInfo .explosionTaken [i ] * EXPLOSION_HIT_REWARD_COEF ;
749+ reward -= drone -> stepInfo .explosionTaken [i ] * e -> explosionHitRewardCoef ;
722750 }
723751 }
724752
725753 if (enemyDrone -> dead && enemyDrone -> diedThisStep ) {
726754 if (!onTeam ) {
727- reward += ENEMY_DEATH_REWARD ;
755+ reward += e -> enemyDeathReward ;
728756 if (drone -> killed [i ]) {
729- reward += ENEMY_KILL_REWARD ;
757+ reward += e -> enemyKillReward ;
730758 }
731759 } else {
732- reward += TEAMMATE_DEATH_PUNISHMENT ;
760+ reward += e -> teammateDeathPunishment ;
733761 if (drone -> killed [i ]) {
734- reward += TEAMMATE_KILL_PUNISHMENT ;
762+ reward += e -> teammateKillPunishment ;
735763 }
736764 }
737765 continue ;
@@ -756,19 +784,19 @@ const float REWARD_EPS = 1.0e-6f;
756784
757785void computeRewards (iwEnv * e , const bool roundOver , const int8_t winner , const int8_t winningTeam ) {
758786 if (roundOver && winner != -1 && winner < e -> numAgents ) {
759- e -> rewards [winner ] += WIN_REWARD ;
787+ e -> rewards [winner ] += e -> winReward ;
760788 }
761789
762790 for (uint8_t i = 0 ; i < e -> numDrones ; i ++ ) {
763791 float reward = 0.0f ;
764792 droneEntity * drone = safe_array_get_at (e -> drones , i );
765793 reward = computeReward (e , drone );
766794 if (!drone -> dead && roundOver && winningTeam == drone -> team ) {
767- reward += WIN_REWARD ;
795+ reward += e -> winReward ;
768796 } else if (drone -> diedThisStep ) {
769- reward = DEATH_PUNISHMENT ;
797+ reward = e -> deathPunishment ;
770798 if (drone -> killedBy == drone -> idx ) {
771- reward += SELF_KILL_PUNISHMENT ;
799+ reward += e -> selfKillPunishment ;
772800 }
773801 }
774802 if (i < e -> numAgents ) {
@@ -973,6 +1001,10 @@ void addLog(iwEnv *e, Log *log) {
9731001 e -> log .stats [j ].totalBursts += log -> stats [j ].totalBursts ;
9741002 e -> log .stats [j ].burstsHit += log -> stats [j ].burstsHit ;
9751003 e -> log .stats [j ].energyEmptied += log -> stats [j ].energyEmptied ;
1004+ e -> log .stats [j ].shieldsBroken += log -> stats [j ].shieldsBroken ;
1005+ e -> log .stats [j ].ownShieldBroken += log -> stats [j ].ownShieldBroken ;
1006+ e -> log .stats [j ].selfKills += log -> stats [j ].selfKills ;
1007+ e -> log .stats [j ].kills += log -> stats [j ].kills ;
9761008
9771009 for (uint8_t k = 0 ; k < NUM_WEAPONS ; k ++ ) {
9781010 e -> log .stats [j ].shotsFired [k ] += log -> stats [j ].shotsFired [k ];
0 commit comments