@@ -367,10 +367,10 @@ void computeObs(iwEnv *e) {
367367 continue ;
368368 }
369369
370- if (agentDrone -> stepInfo .shotHit [i ]) {
370+ if (agentDrone -> stepInfo .shotHit [i ] != 0.0f ) {
371371 hitShot = true;
372372 }
373- if (agentDrone -> stepInfo .shotTaken [i ]) {
373+ if (agentDrone -> stepInfo .shotTaken [i ] != 0.0f ) {
374374 tookShot = true;
375375 }
376376
@@ -539,6 +539,19 @@ iwEnv *initEnv(iwEnv *e, uint8_t numDrones, uint8_t numAgents, int8_t mapIdx, ui
539539 e -> sittingDuck = sittingDuck ;
540540 e -> isTraining = isTraining ;
541541
542+ e -> winReward = WIN_REWARD ;
543+ e -> selfKillPunishment = SELF_KILL_PUNISHMENT ;
544+ e -> enemyDeathReward = ENEMY_DEATH_REWARD ;
545+ e -> enemyKillReward = ENEMY_KILL_REWARD ;
546+ e -> teammateDeathPunishment = TEAMMATE_DEATH_PUNISHMENT ;
547+ e -> teammateKillPunishment = TEAMMATE_KILL_PUNISHMENT ;
548+ e -> deathPunishment = DEATH_PUNISHMENT ;
549+ e -> energyEmptiedPunishment = ENERGY_EMPTY_PUNISHMENT ;
550+ e -> weaponPickupReward = WEAPON_PICKUP_REWARD ;
551+ e -> shieldBreakReward = SHIELD_BREAK_REWARD ;
552+ e -> shotHitRewardCoef = SHOT_HIT_REWARD_COEF ;
553+ e -> explosionHitRewardCoef = EXPLOSION_HIT_REWARD_COEF ;
554+
542555 e -> obsBytes = obsBytes (e -> numDrones );
543556 e -> discreteObsBytes = alignedSize (discreteObsSize (e -> numDrones ) * sizeof (uint8_t ), sizeof (float ));
544557
@@ -583,9 +596,28 @@ iwEnv *initEnv(iwEnv *e, uint8_t numDrones, uint8_t numAgents, int8_t mapIdx, ui
583596 e -> humanDroneInput = 0 ;
584597 e -> connectedControllers = 0 ;
585598
599+ #ifndef NDEBUG
600+ create_array (& e -> debugPoints , 4 );
601+ #endif
602+
586603 return e ;
587604}
588605
606+ void setRewards (iwEnv * e , float winReward , float selfKillPunishment , float enemyDeathReward , float enemyKillReward , float teammateDeathPunishment , float teammateKillPunishment , float deathPunishment , float energyEmptiedPunishment , float weaponPickupReward , float shieldBreakReward , float shotHitRewardCoef , float explosionHitRewardCoef ) {
607+ e -> winReward = winReward ;
608+ e -> selfKillPunishment = selfKillPunishment ;
609+ e -> enemyDeathReward = enemyDeathReward ;
610+ e -> enemyKillReward = enemyKillReward ;
611+ e -> teammateDeathPunishment = teammateDeathPunishment ;
612+ e -> teammateKillPunishment = teammateKillPunishment ;
613+ e -> deathPunishment = deathPunishment ;
614+ e -> energyEmptiedPunishment = energyEmptiedPunishment ;
615+ e -> weaponPickupReward = weaponPickupReward ;
616+ e -> shieldBreakReward = shieldBreakReward ;
617+ e -> shotHitRewardCoef = shotHitRewardCoef ;
618+ e -> explosionHitRewardCoef = explosionHitRewardCoef ;
619+ }
620+
589621void clearEnv (iwEnv * e ) {
590622 // rewards get cleared in stepEnv every step
591623 // memset(e->masks, 1, e->numAgents * sizeof(uint8_t));
@@ -673,36 +705,30 @@ void destroyEnv(iwEnv *e) {
673705 cc_array_destroy (e -> dronePieces );
674706
675707 b2DestroyWorld (e -> worldID );
708+
709+ #ifndef NDEBUG
710+ cc_array_destroy (e -> debugPoints );
711+ #endif
676712}
677713
678714void resetEnv (iwEnv * e ) {
679715 clearEnv (e );
680716 setupEnv (e );
681717}
682718
683- float computeShotReward (const droneEntity * drone , const weaponInformation * weaponInfo ) {
684- const float weaponForce = weaponInfo -> fireMagnitude * weaponInfo -> invMass ;
685- const float scaledForce = (weaponForce * (weaponForce * SHOT_HIT_REWARD_COEF )) + 0.25f ;
686- return scaledForce + computeHitStrength (drone );
687- }
688-
689- float computeExplosionReward (const droneEntity * drone ) {
690- return computeHitStrength (drone ) * EXPLOSION_HIT_REWARD_COEF ;
691- }
692-
693719float computeReward (iwEnv * e , droneEntity * drone ) {
694720 float reward = 0.0f ;
695721
696722 if (drone -> energyFullyDepleted && drone -> energyRefillWait == DRONE_ENERGY_REFILL_EMPTY_WAIT ) {
697- reward += ENERGY_EMPTY_PUNISHMENT ;
723+ reward += e -> energyEmptiedPunishment ;
698724 }
699725
700726 // only reward picking up a weapon if the standard weapon was
701727 // previously held; every weapon is better than the standard
702728 // weapon, but other weapons are situational better so don't
703729 // reward switching a non-standard weapon
704730 if (drone -> stepInfo .pickedUpWeapon && drone -> stepInfo .prevWeapon == STANDARD_WEAPON ) {
705- reward += WEAPON_PICKUP_REWARD ;
731+ reward += e -> weaponPickupReward ;
706732 }
707733
708734 for (uint8_t i = 0 ; i < e -> numDrones ; i ++ ) {
@@ -712,51 +738,51 @@ float computeReward(iwEnv *e, droneEntity *drone) {
712738 droneEntity * enemyDrone = safe_array_get_at (e -> drones , i );
713739 const bool onTeam = drone -> team == enemyDrone -> team ;
714740
715- if (drone -> stepInfo .shotHit [i ] != 0 && !onTeam ) {
716- // subtract 1 from the weapon type because 1 is added so we
717- // can use 0 as no shot was hit
718- const weaponInformation * weaponInfo = weaponInfos [drone -> stepInfo .shotHit [i ] - 1 ];
719- reward += computeShotReward (enemyDrone , weaponInfo );
741+ // TODO: punish for hitting teammates?
742+ if (drone -> stepInfo .shotHit [i ] != 0.0f && !onTeam ) {
743+ reward += drone -> stepInfo .shotHit [i ] * e -> shotHitRewardCoef ;
720744 }
721- if (drone -> stepInfo .explosionHit [i ] && !onTeam ) {
722- reward += computeExplosionReward (enemyDrone );
745+ if (drone -> stepInfo .explosionHit [i ] != 0.0f && !onTeam ) {
746+ reward += drone -> stepInfo .explosionHit [i ] * e -> explosionHitRewardCoef ;
747+ }
748+ if (drone -> stepInfo .brokeShield [i ] && !onTeam ) {
749+ reward += e -> shieldBreakReward ;
723750 }
724751
725752 if (e -> numAgents == e -> numDrones ) {
726- if (drone -> stepInfo .shotTaken [i ] != 0 && !onTeam ) {
727- const weaponInformation * weaponInfo = weaponInfos [drone -> stepInfo .shotTaken [i ] - 1 ];
728- reward -= computeShotReward (drone , weaponInfo ) * 0.5f ;
753+ if (drone -> stepInfo .shotTaken [i ] != 0 ) {
754+ reward -= drone -> stepInfo .shotTaken [i ] * e -> shotHitRewardCoef ;
729755 }
730- if (drone -> stepInfo .explosionTaken [i ] && ! onTeam ) {
731- reward -= computeExplosionReward ( drone ) * 0.5f ;
756+ if (drone -> stepInfo .explosionTaken [i ]) {
757+ reward -= drone -> stepInfo . explosionTaken [ i ] * e -> explosionHitRewardCoef ;
732758 }
733759 }
734760
735761 if (enemyDrone -> dead && enemyDrone -> diedThisStep ) {
736762 if (!onTeam ) {
737- reward += ENEMY_DEATH_REWARD ;
763+ reward += e -> enemyDeathReward ;
738764 if (drone -> killed [i ]) {
739- reward += ENEMY_KILL_REWARD ;
765+ reward += e -> enemyKillReward ;
740766 }
741767 } else {
742- reward += TEAMMATE_DEATH_PUNISHMENT ;
768+ reward += e -> teammateDeathPunishment ;
743769 if (drone -> killed [i ]) {
744- reward += TEAMMATE_KILL_PUNISHMENT ;
770+ reward += e -> teammateKillPunishment ;
745771 }
746772 }
747773 continue ;
748774 }
749775
750- const b2Vec2 enemyDirection = b2Normalize (b2Sub (enemyDrone -> pos , drone -> pos ));
751- const float velocityToEnemy = b2Dot (drone -> lastVelocity , enemyDirection );
752- const float enemyDistance = b2Distance (enemyDrone -> pos , drone -> pos );
753- // stop rewarding approaching an enemy if they're very close
754- // to avoid constant clashing; always reward approaching when
755- // the current weapon is the shotgun, it greatly benefits from
756- // being close to enemies
757- if (velocityToEnemy > 0.1f && (drone -> weaponInfo -> type == SHOTGUN_WEAPON || enemyDistance > DISTANCE_CUTOFF )) {
758- reward += APPROACH_REWARD ;
759- }
776+ // const b2Vec2 enemyDirection = b2Normalize(b2Sub(enemyDrone->pos, drone->pos));
777+ // const float velocityToEnemy = b2Dot(drone->lastVelocity, enemyDirection);
778+ // const float enemyDistance = b2Distance(enemyDrone->pos, drone->pos);
779+ // // stop rewarding approaching an enemy if they're very close
780+ // // to avoid constant clashing; always reward approaching when
781+ // // the current weapon is the shotgun, it greatly benefits from
782+ // // being close to enemies
783+ // if (velocityToEnemy > 0.1f && (drone->weaponInfo->type == SHOTGUN_WEAPON || enemyDistance > DISTANCE_CUTOFF)) {
784+ // reward += APPROACH_REWARD;
785+ // }
760786 }
761787
762788 return reward ;
@@ -766,21 +792,19 @@ const float REWARD_EPS = 1.0e-6f;
766792
767793void computeRewards (iwEnv * e , const bool roundOver , const int8_t winner , const int8_t winningTeam ) {
768794 if (roundOver && winner != -1 && winner < e -> numAgents ) {
769- e -> rewards [winner ] += WIN_REWARD ;
795+ e -> rewards [winner ] += e -> winReward ;
770796 }
771797
772798 for (uint8_t i = 0 ; i < e -> numDrones ; i ++ ) {
773799 float reward = 0.0f ;
774800 droneEntity * drone = safe_array_get_at (e -> drones , i );
775- if (!drone -> dead ) {
776- reward = computeReward (e , drone );
777- if (roundOver && winningTeam == drone -> team ) {
778- reward += WIN_REWARD ;
779- }
801+ reward = computeReward (e , drone );
802+ if (!drone -> dead && roundOver && winningTeam == drone -> team ) {
803+ reward += e -> winReward ;
780804 } else if (drone -> diedThisStep ) {
781- reward = DEATH_PUNISHMENT ;
805+ reward = e -> deathPunishment ;
782806 if (drone -> killedBy == drone -> idx ) {
783- reward += SELF_KILL_PUNISHMENT ;
807+ reward += e -> selfKillPunishment ;
784808 }
785809 }
786810 if (i < e -> numAgents ) {
@@ -985,6 +1009,11 @@ void addLog(iwEnv *e, Log *log) {
9851009 e -> log .stats [j ].totalBursts += log -> stats [j ].totalBursts ;
9861010 e -> log .stats [j ].burstsHit += log -> stats [j ].burstsHit ;
9871011 e -> log .stats [j ].energyEmptied += log -> stats [j ].energyEmptied ;
1012+ e -> log .stats [j ].shieldsBroken += log -> stats [j ].shieldsBroken ;
1013+ e -> log .stats [j ].ownShieldBroken += log -> stats [j ].ownShieldBroken ;
1014+ e -> log .stats [j ].selfKills += log -> stats [j ].selfKills ;
1015+ e -> log .stats [j ].kills += log -> stats [j ].kills ;
1016+ e -> log .stats [j ].unknownKills += log -> stats [j ].unknownKills ;
9881017
9891018 for (uint8_t k = 0 ; k < NUM_WEAPONS ; k ++ ) {
9901019 e -> log .stats [j ].shotsFired [k ] += log -> stats [j ].shotsFired [k ];
@@ -1006,6 +1035,7 @@ void addLog(iwEnv *e, Log *log) {
10061035 e -> log .n += 1.0f ;
10071036}
10081037
1038+ // TODO: 2nd agent doesn't seem to work right
10091039void stepEnv (iwEnv * e ) {
10101040 if (e -> needsReset ) {
10111041 DEBUG_LOG ("Resetting environment" );
@@ -1017,6 +1047,14 @@ void stepEnv(iwEnv *e) {
10171047#endif
10181048 }
10191049
1050+ #ifndef NDEBUG
1051+ for (uint8_t i = 0 ; i < cc_array_size (e -> debugPoints ); i ++ ) {
1052+ debugPoint * point = safe_array_get_at (e -> debugPoints , i );
1053+ fastFree (point );
1054+ }
1055+ cc_array_remove_all (e -> debugPoints );
1056+ #endif
1057+
10201058 agentActions stepActions [e -> numDrones ];
10211059 memset (stepActions , 0x0 , e -> numDrones * sizeof (agentActions ));
10221060
0 commit comments