66[ RequireComponent ( typeof ( JointDriveController ) ) ] // Required to set joint forces
77public class WormAgent : Agent
88{
9- [ Header ( "Target To Walk Towards" ) ]
10- [ Space ( 10 ) ]
9+ [ Header ( "Target To Walk Towards" ) ] [ Space ( 10 ) ]
1110 public Transform target ;
12-
1311 public Transform ground ;
14- public bool detectTargets ;
15- public bool targetIsStatic ;
16- public bool respawnTargetWhenTouched ;
17- public float targetSpawnRadius ;
1812
19- [ Header ( "Body Parts" ) ] [ Space ( 10 ) ]
20- public Transform bodySegment0 ;
13+ [ Header ( "Body Parts" ) ] [ Space ( 10 ) ] public Transform bodySegment0 ;
2114 public Transform bodySegment1 ;
2215 public Transform bodySegment2 ;
2316 public Transform bodySegment3 ;
2417
25- [ Header ( "Joint Settings" ) ] [ Space ( 10 ) ]
26- JointDriveController m_JdController ;
18+ [ Header ( "Joint Settings" ) ] [ Space ( 10 ) ] JointDriveController m_JdController ;
2719 Vector3 m_DirToTarget ;
2820 float m_MovingTowardsDot ;
2921 float m_FacingDot ;
3022
31- [ Header ( "Reward Functions To Use" ) ]
32- [ Space ( 10 ) ]
23+ [ Header ( "Reward Functions To Use" ) ] [ Space ( 10 ) ]
3324 public bool rewardMovingTowardsTarget ; // Agent should move towards target
3425
3526 public bool rewardFacingTarget ; // Agent should face the target
@@ -50,21 +41,14 @@ public override void Initialize()
5041 m_JdController . SetupBodyPart ( bodySegment1 ) ;
5142 m_JdController . SetupBodyPart ( bodySegment2 ) ;
5243 m_JdController . SetupBodyPart ( bodySegment3 ) ;
53-
54- //We only want the head to detect the target
55- //So we need to remove TargetContact from everything else
56- //This is a temp fix till we can redesign
57- DestroyImmediate ( bodySegment1 . GetComponent < TargetContact > ( ) ) ;
58- DestroyImmediate ( bodySegment2 . GetComponent < TargetContact > ( ) ) ;
59- DestroyImmediate ( bodySegment3 . GetComponent < TargetContact > ( ) ) ;
6044 }
6145
6246
6347 //Get Joint Rotation Relative to the Connected Rigidbody
6448 //We want to collect this info because it is the actual rotation, not the "target rotation"
6549 public Quaternion GetJointRotation ( ConfigurableJoint joint )
6650 {
67- return ( Quaternion . FromToRotation ( joint . axis , joint . connectedBody . transform . rotation . eulerAngles ) ) ;
51+ return ( Quaternion . FromToRotation ( joint . axis , joint . connectedBody . transform . rotation . eulerAngles ) ) ;
6852 }
6953
7054 /// <summary>
@@ -78,7 +62,8 @@ public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor)
7862 var velocityRelativeToLookRotationToTarget = m_TargetDirMatrix . inverse . MultiplyVector ( rb . velocity ) ;
7963 sensor . AddObservation ( velocityRelativeToLookRotationToTarget ) ;
8064
81- var angularVelocityRelativeToLookRotationToTarget = m_TargetDirMatrix . inverse . MultiplyVector ( rb . angularVelocity ) ;
65+ var angularVelocityRelativeToLookRotationToTarget =
66+ m_TargetDirMatrix . inverse . MultiplyVector ( rb . angularVelocity ) ;
8267 sensor . AddObservation ( angularVelocityRelativeToLookRotationToTarget ) ;
8368
8469 if ( bp . rb . transform != bodySegment0 )
@@ -103,18 +88,19 @@ public override void CollectObservations(VectorSensor sensor)
10388 float maxDist = 10 ;
10489 if ( Physics . Raycast ( bodySegment0 . position , Vector3 . down , out hit , maxDist ) )
10590 {
106- sensor . AddObservation ( hit . distance / maxDist ) ;
91+ sensor . AddObservation ( hit . distance / maxDist ) ;
10792 }
10893 else
10994 sensor . AddObservation ( 1 ) ;
11095
111- foreach ( var bodyPart in m_JdController . bodyPartsDict . Values )
96+ foreach ( var bodyPart in m_JdController . bodyPartsList )
11297 {
11398 CollectObservationBodyPart ( bodyPart , sensor ) ;
11499 }
115100
116101 //Rotation delta between the matrix and the head
117- Quaternion headRotationDeltaFromMatrixRot = Quaternion . Inverse ( m_TargetDirMatrix . rotation ) * bodySegment0 . rotation ;
102+ Quaternion headRotationDeltaFromMatrixRot =
103+ Quaternion . Inverse ( m_TargetDirMatrix . rotation ) * bodySegment0 . rotation ;
118104 sensor . AddObservation ( headRotationDeltaFromMatrixRot ) ;
119105 }
120106
@@ -124,20 +110,6 @@ public override void CollectObservations(VectorSensor sensor)
124110 public void TouchedTarget ( )
125111 {
126112 AddReward ( 1f ) ;
127- if ( respawnTargetWhenTouched )
128- {
129- GetRandomTargetPos ( ) ;
130- }
131- }
132-
133- /// <summary>
134- /// Moves target to a random position within specified radius.
135- /// </summary>
136- public void GetRandomTargetPos ( )
137- {
138- var newTargetPos = Random . insideUnitSphere * targetSpawnRadius ;
139- newTargetPos . y = 5 ;
140- target . position = newTargetPos + ground . position ;
141113 }
142114
143115 public override void OnActionReceived ( float [ ] vectorAction )
@@ -156,25 +128,15 @@ public override void OnActionReceived(float[] vectorAction)
156128 bpDict [ bodySegment2 ] . SetJointStrength ( vectorAction [ ++ i ] ) ;
157129 bpDict [ bodySegment3 ] . SetJointStrength ( vectorAction [ ++ i ] ) ;
158130
159- if ( bodySegment0 . position . y < ground . position . y - 2 )
131+ // Detect if worm fell off/through platform
132+ if ( bodySegment0 . position . y < ground . position . y - 2 )
160133 {
161134 EndEpisode ( ) ;
162135 }
163136 }
164137
165138 void FixedUpdate ( )
166139 {
167- if ( detectTargets )
168- {
169- foreach ( var bodyPart in m_JdController . bodyPartsDict . Values )
170- {
171- if ( bodyPart . targetContact && bodyPart . targetContact . touchingTarget )
172- {
173- TouchedTarget ( ) ;
174- }
175- }
176- }
177-
178140 // Set reward for this step according to mixture of the following elements.
179141 if ( rewardMovingTowardsTarget )
180142 {
@@ -197,7 +159,8 @@ void FixedUpdate()
197159 /// </summary>
198160 void RewardFunctionMovingTowards ( )
199161 {
200- m_MovingTowardsDot = Vector3 . Dot ( m_JdController . bodyPartsDict [ bodySegment0 ] . rb . velocity , m_DirToTarget . normalized ) ;
162+ m_MovingTowardsDot =
163+ Vector3 . Dot ( m_JdController . bodyPartsDict [ bodySegment0 ] . rb . velocity , m_DirToTarget . normalized ) ;
201164 AddReward ( 0.01f * m_MovingTowardsDot ) ;
202165 }
203166
@@ -211,7 +174,7 @@ void RewardFunctionFacingTarget()
211174 }
212175
213176 /// <summary>
214- /// Existential penalty for time-contrained tasks.
177+ /// Existential penalty for time-constrained tasks.
215178 /// </summary>
216179 void RewardFunctionTimePenalty ( )
217180 {
@@ -227,15 +190,12 @@ public override void OnEpisodeBegin()
227190 {
228191 bodyPart . Reset ( bodyPart ) ;
229192 }
193+
230194 if ( m_DirToTarget != Vector3 . zero )
231195 {
232196 transform . rotation = Quaternion . LookRotation ( m_DirToTarget ) ;
233197 }
234- transform . Rotate ( Vector3 . up , Random . Range ( 0.0f , 360.0f ) ) ;
235198
236- if ( ! targetIsStatic )
237- {
238- GetRandomTargetPos ( ) ;
239- }
199+ transform . Rotate ( Vector3 . up , Random . Range ( 0.0f , 360.0f ) ) ;
240200 }
241201}
0 commit comments