6
6
[ RequireComponent ( typeof ( JointDriveController ) ) ] // Required to set joint forces
7
7
public class WormAgent : Agent
8
8
{
9
- [ Header ( "Target To Walk Towards" ) ]
10
- [ Space ( 10 ) ]
9
+ [ Header ( "Target To Walk Towards" ) ] [ Space ( 10 ) ]
11
10
public Transform target ;
12
-
13
11
public Transform ground ;
14
- public bool detectTargets ;
15
- public bool targetIsStatic ;
16
- public bool respawnTargetWhenTouched ;
17
- public float targetSpawnRadius ;
18
12
19
- [ Header ( "Body Parts" ) ] [ Space ( 10 ) ]
20
- public Transform bodySegment0 ;
13
+ [ Header ( "Body Parts" ) ] [ Space ( 10 ) ] public Transform bodySegment0 ;
21
14
public Transform bodySegment1 ;
22
15
public Transform bodySegment2 ;
23
16
public Transform bodySegment3 ;
24
17
25
- [ Header ( "Joint Settings" ) ] [ Space ( 10 ) ]
26
- JointDriveController m_JdController ;
18
+ [ Header ( "Joint Settings" ) ] [ Space ( 10 ) ] JointDriveController m_JdController ;
27
19
Vector3 m_DirToTarget ;
28
20
float m_MovingTowardsDot ;
29
21
float m_FacingDot ;
30
22
31
- [ Header ( "Reward Functions To Use" ) ]
32
- [ Space ( 10 ) ]
23
+ [ Header ( "Reward Functions To Use" ) ] [ Space ( 10 ) ]
33
24
public bool rewardMovingTowardsTarget ; // Agent should move towards target
34
25
35
26
public bool rewardFacingTarget ; // Agent should face the target
@@ -50,21 +41,14 @@ public override void Initialize()
50
41
m_JdController . SetupBodyPart ( bodySegment1 ) ;
51
42
m_JdController . SetupBodyPart ( bodySegment2 ) ;
52
43
m_JdController . SetupBodyPart ( bodySegment3 ) ;
53
-
54
- //We only want the head to detect the target
55
- //So we need to remove TargetContact from everything else
56
- //This is a temp fix till we can redesign
57
- DestroyImmediate ( bodySegment1 . GetComponent < TargetContact > ( ) ) ;
58
- DestroyImmediate ( bodySegment2 . GetComponent < TargetContact > ( ) ) ;
59
- DestroyImmediate ( bodySegment3 . GetComponent < TargetContact > ( ) ) ;
60
44
}
61
45
62
46
63
47
//Get Joint Rotation Relative to the Connected Rigidbody
64
48
//We want to collect this info because it is the actual rotation, not the "target rotation"
65
49
public Quaternion GetJointRotation ( ConfigurableJoint joint )
66
50
{
67
- return ( Quaternion . FromToRotation ( joint . axis , joint . connectedBody . transform . rotation . eulerAngles ) ) ;
51
+ return ( Quaternion . FromToRotation ( joint . axis , joint . connectedBody . transform . rotation . eulerAngles ) ) ;
68
52
}
69
53
70
54
/// <summary>
@@ -78,7 +62,8 @@ public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor)
78
62
var velocityRelativeToLookRotationToTarget = m_TargetDirMatrix . inverse . MultiplyVector ( rb . velocity ) ;
79
63
sensor . AddObservation ( velocityRelativeToLookRotationToTarget ) ;
80
64
81
- var angularVelocityRelativeToLookRotationToTarget = m_TargetDirMatrix . inverse . MultiplyVector ( rb . angularVelocity ) ;
65
+ var angularVelocityRelativeToLookRotationToTarget =
66
+ m_TargetDirMatrix . inverse . MultiplyVector ( rb . angularVelocity ) ;
82
67
sensor . AddObservation ( angularVelocityRelativeToLookRotationToTarget ) ;
83
68
84
69
if ( bp . rb . transform != bodySegment0 )
@@ -103,18 +88,19 @@ public override void CollectObservations(VectorSensor sensor)
103
88
float maxDist = 10 ;
104
89
if ( Physics . Raycast ( bodySegment0 . position , Vector3 . down , out hit , maxDist ) )
105
90
{
106
- sensor . AddObservation ( hit . distance / maxDist ) ;
91
+ sensor . AddObservation ( hit . distance / maxDist ) ;
107
92
}
108
93
else
109
94
sensor . AddObservation ( 1 ) ;
110
95
111
- foreach ( var bodyPart in m_JdController . bodyPartsDict . Values )
96
+ foreach ( var bodyPart in m_JdController . bodyPartsList )
112
97
{
113
98
CollectObservationBodyPart ( bodyPart , sensor ) ;
114
99
}
115
100
116
101
//Rotation delta between the matrix and the head
117
- Quaternion headRotationDeltaFromMatrixRot = Quaternion . Inverse ( m_TargetDirMatrix . rotation ) * bodySegment0 . rotation ;
102
+ Quaternion headRotationDeltaFromMatrixRot =
103
+ Quaternion . Inverse ( m_TargetDirMatrix . rotation ) * bodySegment0 . rotation ;
118
104
sensor . AddObservation ( headRotationDeltaFromMatrixRot ) ;
119
105
}
120
106
@@ -124,20 +110,6 @@ public override void CollectObservations(VectorSensor sensor)
124
110
public void TouchedTarget ( )
125
111
{
126
112
AddReward ( 1f ) ;
127
- if ( respawnTargetWhenTouched )
128
- {
129
- GetRandomTargetPos ( ) ;
130
- }
131
- }
132
-
133
- /// <summary>
134
- /// Moves target to a random position within specified radius.
135
- /// </summary>
136
- public void GetRandomTargetPos ( )
137
- {
138
- var newTargetPos = Random . insideUnitSphere * targetSpawnRadius ;
139
- newTargetPos . y = 5 ;
140
- target . position = newTargetPos + ground . position ;
141
113
}
142
114
143
115
public override void OnActionReceived ( float [ ] vectorAction )
@@ -156,25 +128,15 @@ public override void OnActionReceived(float[] vectorAction)
156
128
bpDict [ bodySegment2 ] . SetJointStrength ( vectorAction [ ++ i ] ) ;
157
129
bpDict [ bodySegment3 ] . SetJointStrength ( vectorAction [ ++ i ] ) ;
158
130
159
- if ( bodySegment0 . position . y < ground . position . y - 2 )
131
+ // Detect if worm fell off/through platform
132
+ if ( bodySegment0 . position . y < ground . position . y - 2 )
160
133
{
161
134
EndEpisode ( ) ;
162
135
}
163
136
}
164
137
165
138
void FixedUpdate ( )
166
139
{
167
- if ( detectTargets )
168
- {
169
- foreach ( var bodyPart in m_JdController . bodyPartsDict . Values )
170
- {
171
- if ( bodyPart . targetContact && bodyPart . targetContact . touchingTarget )
172
- {
173
- TouchedTarget ( ) ;
174
- }
175
- }
176
- }
177
-
178
140
// Set reward for this step according to mixture of the following elements.
179
141
if ( rewardMovingTowardsTarget )
180
142
{
@@ -197,7 +159,8 @@ void FixedUpdate()
197
159
/// </summary>
198
160
void RewardFunctionMovingTowards ( )
199
161
{
200
- m_MovingTowardsDot = Vector3 . Dot ( m_JdController . bodyPartsDict [ bodySegment0 ] . rb . velocity , m_DirToTarget . normalized ) ;
162
+ m_MovingTowardsDot =
163
+ Vector3 . Dot ( m_JdController . bodyPartsDict [ bodySegment0 ] . rb . velocity , m_DirToTarget . normalized ) ;
201
164
AddReward ( 0.01f * m_MovingTowardsDot ) ;
202
165
}
203
166
@@ -211,7 +174,7 @@ void RewardFunctionFacingTarget()
211
174
}
212
175
213
176
/// <summary>
214
- /// Existential penalty for time-contrained tasks.
177
+ /// Existential penalty for time-constrained tasks.
215
178
/// </summary>
216
179
void RewardFunctionTimePenalty ( )
217
180
{
@@ -227,15 +190,12 @@ public override void OnEpisodeBegin()
227
190
{
228
191
bodyPart . Reset ( bodyPart ) ;
229
192
}
193
+
230
194
if ( m_DirToTarget != Vector3 . zero )
231
195
{
232
196
transform . rotation = Quaternion . LookRotation ( m_DirToTarget ) ;
233
197
}
234
- transform . Rotate ( Vector3 . up , Random . Range ( 0.0f , 360.0f ) ) ;
235
198
236
- if ( ! targetIsStatic )
237
- {
238
- GetRandomTargetPos ( ) ;
239
- }
199
+ transform . Rotate ( Vector3 . up , Random . Range ( 0.0f , 360.0f ) ) ;
240
200
}
241
201
}
0 commit comments