@@ -37,7 +37,9 @@ internal void SetPolicy(IPolicy policy)
3737
3838 public int initializeAgentCalls ;
3939 public int collectObservationsCalls ;
40+ public int collectObservationsCallsSinceLastReset ;
4041 public int agentActionCalls ;
42+ public int agentActionCallsSinceLastReset ;
4143 public int agentResetCalls ;
4244 public override void InitializeAgent ( )
4345 {
@@ -54,18 +56,22 @@ public override void InitializeAgent()
5456 public override void CollectObservations ( )
5557 {
5658 collectObservationsCalls += 1 ;
59+ collectObservationsCallsSinceLastReset += 1 ;
5760 AddVectorObs ( 0f ) ;
5861 }
5962
6063 public override void AgentAction ( float [ ] vectorAction )
6164 {
6265 agentActionCalls += 1 ;
66+ agentActionCallsSinceLastReset += 1 ;
6367 AddReward ( 0.1f ) ;
6468 }
6569
6670 public override void AgentReset ( )
6771 {
6872 agentResetCalls += 1 ;
73+ collectObservationsCallsSinceLastReset = 0 ;
74+ agentActionCallsSinceLastReset = 0 ;
6975 }
7076
7177 public override float [ ] Heuristic ( )
@@ -484,7 +490,7 @@ public void TestCumulativeReward()
484490 var j = 0 ;
485491 for ( var i = 0 ; i < 500 ; i ++ )
486492 {
487- if ( i % 20 == 0 )
493+ if ( i % 21 == 0 )
488494 {
489495 j = 0 ;
490496 }
@@ -500,5 +506,40 @@ public void TestCumulativeReward()
500506 aca . EnvironmentStep ( ) ;
501507 }
502508 }
509+
510+ [ Test ]
511+ public void TestMaxStepsReset ( )
512+ {
513+ var agentGo1 = new GameObject ( "TestAgent" ) ;
514+ agentGo1 . AddComponent < TestAgent > ( ) ;
515+ var agent1 = agentGo1 . GetComponent < TestAgent > ( ) ;
516+ var aca = Academy . Instance ;
517+
518+ var decisionRequester = agent1 . gameObject . AddComponent < DecisionRequester > ( ) ;
519+ decisionRequester . DecisionPeriod = 1 ;
520+ decisionRequester . Awake ( ) ;
521+
522+ var maxStep = 6 ;
523+ agent1 . maxStep = maxStep ;
524+ agent1 . LazyInitialize ( ) ;
525+
526+ for ( var i = 0 ; i < 15 ; i ++ )
527+ {
528+ // We expect resets to occur when there are maxSteps actions since the last reset (and on the first step)
529+ var expectReset = agent1 . agentActionCallsSinceLastReset == maxStep || ( i == 0 ) ;
530+ var previousNumResets = agent1 . agentResetCalls ;
531+
532+ aca . EnvironmentStep ( ) ;
533+
534+ if ( expectReset )
535+ {
536+ Assert . AreEqual ( previousNumResets + 1 , agent1 . agentResetCalls ) ;
537+ }
538+ else
539+ {
540+ Assert . AreEqual ( previousNumResets , agent1 . agentResetCalls ) ;
541+ }
542+ }
543+ }
503544 }
504545}
0 commit comments