Skip to content

Commit 02175a1

Browse files
Fix worm environment. (#4337)
* Fix worm environment.
* Remove dead code from WormAgent.cs.
* Remove target logic; change the observation loop to iterate over the body-parts list instead of the dict.

Co-authored-by: HH <[email protected]>
1 parent 0d4d08c commit 02175a1

File tree

2 files changed

+64
-58
lines changed

2 files changed

+64
-58
lines changed

Project/Assets/ML-Agents/Examples/Worm/Prefabs/PlatformWormDynamicTarget.prefab

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,41 @@ Rigidbody:
1616
m_Interpolate: 0
1717
m_Constraints: 0
1818
m_CollisionDetection: 0
19+
--- !u!114 &8042564747579887
20+
MonoBehaviour:
21+
m_ObjectHideFlags: 0
22+
m_CorrespondingSourceObject: {fileID: 0}
23+
m_PrefabInstance: {fileID: 0}
24+
m_PrefabAsset: {fileID: 0}
25+
m_GameObject: {fileID: 7516457449653310668}
26+
m_Enabled: 1
27+
m_EditorHideFlags: 0
28+
m_Script: {fileID: 11500000, guid: 3c8f113a8b8d94967b1b1782c549be81, type: 3}
29+
m_Name:
30+
m_EditorClassIdentifier:
31+
tagToDetect: agent
32+
spawnRadius: 40
33+
respawnIfTouched: 1
34+
respawnIfFallsOffPlatform: 1
35+
fallDistance: 5
36+
onTriggerEnterEvent:
37+
m_PersistentCalls:
38+
m_Calls: []
39+
onTriggerStayEvent:
40+
m_PersistentCalls:
41+
m_Calls: []
42+
onTriggerExitEvent:
43+
m_PersistentCalls:
44+
m_Calls: []
45+
onCollisionEnterEvent:
46+
m_PersistentCalls:
47+
m_Calls: []
48+
onCollisionStayEvent:
49+
m_PersistentCalls:
50+
m_Calls: []
51+
onCollisionExitEvent:
52+
m_PersistentCalls:
53+
m_Calls: []
1954
--- !u!1001 &906401165941233076
2055
PrefabInstance:
2156
m_ObjectHideFlags: 0
@@ -93,6 +128,11 @@ PrefabInstance:
93128
propertyPath: ground
94129
value:
95130
objectReference: {fileID: 7519759559437056804}
131+
- target: {fileID: 6060305997946326746, guid: 3ebcde4cf2d5c4c029e2a5ce3d853aba,
132+
type: 3}
133+
propertyPath: m_TagString
134+
value: agent
135+
objectReference: {fileID: 0}
96136
m_RemovedComponents: []
97137
m_SourcePrefab: {fileID: 100100000, guid: 3ebcde4cf2d5c4c029e2a5ce3d853aba, type: 3}
98138
--- !u!1001 &7202236613889278392
@@ -202,6 +242,12 @@ Transform:
202242
type: 3}
203243
m_PrefabInstance: {fileID: 7202236613889278392}
204244
m_PrefabAsset: {fileID: 0}
245+
--- !u!1 &7516457449653310668 stripped
246+
GameObject:
247+
m_CorrespondingSourceObject: {fileID: 845742365997159796, guid: d6fc96a99a9754f07b48abf1e0d55a5c,
248+
type: 3}
249+
m_PrefabInstance: {fileID: 7202236613889278392}
250+
m_PrefabAsset: {fileID: 0}
205251
--- !u!4 &7513373574146463010 stripped
206252
Transform:
207253
m_CorrespondingSourceObject: {fileID: 844321025358320794, guid: d6fc96a99a9754f07b48abf1e0d55a5c,

Project/Assets/ML-Agents/Examples/Worm/Scripts/WormAgent.cs

Lines changed: 18 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,21 @@
66
[RequireComponent(typeof(JointDriveController))] // Required to set joint forces
77
public class WormAgent : Agent
88
{
9-
[Header("Target To Walk Towards")]
10-
[Space(10)]
9+
[Header("Target To Walk Towards")] [Space(10)]
1110
public Transform target;
12-
1311
public Transform ground;
14-
public bool detectTargets;
15-
public bool targetIsStatic;
16-
public bool respawnTargetWhenTouched;
17-
public float targetSpawnRadius;
1812

19-
[Header("Body Parts")] [Space(10)]
20-
public Transform bodySegment0;
13+
[Header("Body Parts")] [Space(10)] public Transform bodySegment0;
2114
public Transform bodySegment1;
2215
public Transform bodySegment2;
2316
public Transform bodySegment3;
2417

25-
[Header("Joint Settings")] [Space(10)]
26-
JointDriveController m_JdController;
18+
[Header("Joint Settings")] [Space(10)] JointDriveController m_JdController;
2719
Vector3 m_DirToTarget;
2820
float m_MovingTowardsDot;
2921
float m_FacingDot;
3022

31-
[Header("Reward Functions To Use")]
32-
[Space(10)]
23+
[Header("Reward Functions To Use")] [Space(10)]
3324
public bool rewardMovingTowardsTarget; // Agent should move towards target
3425

3526
public bool rewardFacingTarget; // Agent should face the target
@@ -50,21 +41,14 @@ public override void Initialize()
5041
m_JdController.SetupBodyPart(bodySegment1);
5142
m_JdController.SetupBodyPart(bodySegment2);
5243
m_JdController.SetupBodyPart(bodySegment3);
53-
54-
//We only want the head to detect the target
55-
//So we need to remove TargetContact from everything else
56-
//This is a temp fix till we can redesign
57-
DestroyImmediate(bodySegment1.GetComponent<TargetContact>());
58-
DestroyImmediate(bodySegment2.GetComponent<TargetContact>());
59-
DestroyImmediate(bodySegment3.GetComponent<TargetContact>());
6044
}
6145

6246

6347
//Get Joint Rotation Relative to the Connected Rigidbody
6448
//We want to collect this info because it is the actual rotation, not the "target rotation"
6549
public Quaternion GetJointRotation(ConfigurableJoint joint)
6650
{
67-
return(Quaternion.FromToRotation(joint.axis, joint.connectedBody.transform.rotation.eulerAngles));
51+
return (Quaternion.FromToRotation(joint.axis, joint.connectedBody.transform.rotation.eulerAngles));
6852
}
6953

7054
/// <summary>
@@ -78,7 +62,8 @@ public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor)
7862
var velocityRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(rb.velocity);
7963
sensor.AddObservation(velocityRelativeToLookRotationToTarget);
8064

81-
var angularVelocityRelativeToLookRotationToTarget = m_TargetDirMatrix.inverse.MultiplyVector(rb.angularVelocity);
65+
var angularVelocityRelativeToLookRotationToTarget =
66+
m_TargetDirMatrix.inverse.MultiplyVector(rb.angularVelocity);
8267
sensor.AddObservation(angularVelocityRelativeToLookRotationToTarget);
8368

8469
if (bp.rb.transform != bodySegment0)
@@ -103,18 +88,19 @@ public override void CollectObservations(VectorSensor sensor)
10388
float maxDist = 10;
10489
if (Physics.Raycast(bodySegment0.position, Vector3.down, out hit, maxDist))
10590
{
106-
sensor.AddObservation(hit.distance/maxDist);
91+
sensor.AddObservation(hit.distance / maxDist);
10792
}
10893
else
10994
sensor.AddObservation(1);
11095

111-
foreach (var bodyPart in m_JdController.bodyPartsDict.Values)
96+
foreach (var bodyPart in m_JdController.bodyPartsList)
11297
{
11398
CollectObservationBodyPart(bodyPart, sensor);
11499
}
115100

116101
//Rotation delta between the matrix and the head
117-
Quaternion headRotationDeltaFromMatrixRot = Quaternion.Inverse(m_TargetDirMatrix.rotation) * bodySegment0.rotation;
102+
Quaternion headRotationDeltaFromMatrixRot =
103+
Quaternion.Inverse(m_TargetDirMatrix.rotation) * bodySegment0.rotation;
118104
sensor.AddObservation(headRotationDeltaFromMatrixRot);
119105
}
120106

@@ -124,20 +110,6 @@ public override void CollectObservations(VectorSensor sensor)
124110
public void TouchedTarget()
125111
{
126112
AddReward(1f);
127-
if (respawnTargetWhenTouched)
128-
{
129-
GetRandomTargetPos();
130-
}
131-
}
132-
133-
/// <summary>
134-
/// Moves target to a random position within specified radius.
135-
/// </summary>
136-
public void GetRandomTargetPos()
137-
{
138-
var newTargetPos = Random.insideUnitSphere * targetSpawnRadius;
139-
newTargetPos.y = 5;
140-
target.position = newTargetPos + ground.position;
141113
}
142114

143115
public override void OnActionReceived(float[] vectorAction)
@@ -156,25 +128,15 @@ public override void OnActionReceived(float[] vectorAction)
156128
bpDict[bodySegment2].SetJointStrength(vectorAction[++i]);
157129
bpDict[bodySegment3].SetJointStrength(vectorAction[++i]);
158130

159-
if (bodySegment0.position.y < ground.position.y -2)
131+
// Detect if worm fell off/through platform
132+
if (bodySegment0.position.y < ground.position.y - 2)
160133
{
161134
EndEpisode();
162135
}
163136
}
164137

165138
void FixedUpdate()
166139
{
167-
if (detectTargets)
168-
{
169-
foreach (var bodyPart in m_JdController.bodyPartsDict.Values)
170-
{
171-
if (bodyPart.targetContact && bodyPart.targetContact.touchingTarget)
172-
{
173-
TouchedTarget();
174-
}
175-
}
176-
}
177-
178140
// Set reward for this step according to mixture of the following elements.
179141
if (rewardMovingTowardsTarget)
180142
{
@@ -197,7 +159,8 @@ void FixedUpdate()
197159
/// </summary>
198160
void RewardFunctionMovingTowards()
199161
{
200-
m_MovingTowardsDot = Vector3.Dot(m_JdController.bodyPartsDict[bodySegment0].rb.velocity, m_DirToTarget.normalized);
162+
m_MovingTowardsDot =
163+
Vector3.Dot(m_JdController.bodyPartsDict[bodySegment0].rb.velocity, m_DirToTarget.normalized);
201164
AddReward(0.01f * m_MovingTowardsDot);
202165
}
203166

@@ -211,7 +174,7 @@ void RewardFunctionFacingTarget()
211174
}
212175

213176
/// <summary>
214-
/// Existential penalty for time-contrained tasks.
177+
/// Existential penalty for time-constrained tasks.
215178
/// </summary>
216179
void RewardFunctionTimePenalty()
217180
{
@@ -227,15 +190,12 @@ public override void OnEpisodeBegin()
227190
{
228191
bodyPart.Reset(bodyPart);
229192
}
193+
230194
if (m_DirToTarget != Vector3.zero)
231195
{
232196
transform.rotation = Quaternion.LookRotation(m_DirToTarget);
233197
}
234-
transform.Rotate(Vector3.up, Random.Range(0.0f, 360.0f));
235198

236-
if (!targetIsStatic)
237-
{
238-
GetRandomTargetPos();
239-
}
199+
transform.Rotate(Vector3.up, Random.Range(0.0f, 360.0f));
240200
}
241201
}

0 commit comments

Comments
 (0)