Skip to content

Commit 48b2dec

Browse files
Fix off-by-one error on AgentReset and maxSteps (#3394)
* Fix ballance ball 100 reward * Re-test * Add test for maxSteps and number of AgentActions Co-authored-by: Chris Elion <[email protected]>
1 parent 566770f commit 48b2dec

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

com.unity.ml-agents/Runtime/Agent.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ void SendInfo()
753753
/// Used by the brain to make the agent perform a step.
754754
void AgentStep()
755755
{
756-
if ((m_StepCount >= maxStep - 1) && (maxStep > 0))
756+
if ((m_StepCount >= maxStep) && (maxStep > 0))
757757
{
758758
NotifyAgentDone(true);
759759
_AgentReset();
@@ -762,6 +762,7 @@ void AgentStep()
762762
{
763763
m_StepCount += 1;
764764
}
765+
765766
if ((m_RequestAction) && (m_Brain != null))
766767
{
767768
m_RequestAction = false;

com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ internal void SetPolicy(IPolicy policy)
3737

3838
public int initializeAgentCalls;
3939
public int collectObservationsCalls;
40+
public int collectObservationsCallsSinceLastReset;
4041
public int agentActionCalls;
42+
public int agentActionCallsSinceLastReset;
4143
public int agentResetCalls;
4244
public override void InitializeAgent()
4345
{
@@ -54,18 +56,22 @@ public override void InitializeAgent()
5456
public override void CollectObservations()
5557
{
5658
collectObservationsCalls += 1;
59+
collectObservationsCallsSinceLastReset += 1;
5760
AddVectorObs(0f);
5861
}
5962

6063
public override void AgentAction(float[] vectorAction)
6164
{
6265
agentActionCalls += 1;
66+
agentActionCallsSinceLastReset += 1;
6367
AddReward(0.1f);
6468
}
6569

6670
public override void AgentReset()
6771
{
6872
agentResetCalls += 1;
73+
collectObservationsCallsSinceLastReset = 0;
74+
agentActionCallsSinceLastReset = 0;
6975
}
7076

7177
public override float[] Heuristic()
@@ -484,7 +490,7 @@ public void TestCumulativeReward()
484490
var j = 0;
485491
for (var i = 0; i < 500; i++)
486492
{
487-
if (i % 20 == 0)
493+
if (i % 21 == 0)
488494
{
489495
j = 0;
490496
}
@@ -500,5 +506,40 @@ public void TestCumulativeReward()
500506
aca.EnvironmentStep();
501507
}
502508
}
509+
510+
[Test]
511+
public void TestMaxStepsReset()
512+
{
513+
var agentGo1 = new GameObject("TestAgent");
514+
agentGo1.AddComponent<TestAgent>();
515+
var agent1 = agentGo1.GetComponent<TestAgent>();
516+
var aca = Academy.Instance;
517+
518+
var decisionRequester = agent1.gameObject.AddComponent<DecisionRequester>();
519+
decisionRequester.DecisionPeriod = 1;
520+
decisionRequester.Awake();
521+
522+
var maxStep = 6;
523+
agent1.maxStep = maxStep;
524+
agent1.LazyInitialize();
525+
526+
for (var i = 0; i < 15; i++)
527+
{
528+
// We expect resets to occur when there are maxSteps actions since the last reset (and on the first step)
529+
var expectReset = agent1.agentActionCallsSinceLastReset == maxStep || (i == 0);
530+
var previousNumResets = agent1.agentResetCalls;
531+
532+
aca.EnvironmentStep();
533+
534+
if (expectReset)
535+
{
536+
Assert.AreEqual(previousNumResets + 1, agent1.agentResetCalls);
537+
}
538+
else
539+
{
540+
Assert.AreEqual(previousNumResets, agent1.agentResetCalls);
541+
}
542+
}
543+
}
503544
}
504545
}

0 commit comments

Comments
 (0)