Skip to content

Commit 1033e59

Browse files
Add files via upload
1 parent 64bf821 commit 1033e59

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ public class CrawlerAgent : Agent
4747
bool isNewDecisionStep;
4848
int currentDecisionStep;
4949

50+
private Transform workingTransform;
51+
5052
public override void InitializeAgent()
5153
{
5254
jdController = GetComponent<JointDriveController>();
@@ -62,6 +64,8 @@ public override void InitializeAgent()
6264
jdController.SetupBodyPart(leg2Lower);
6365
jdController.SetupBodyPart(leg3Upper);
6466
jdController.SetupBodyPart(leg3Lower);
67+
68+
workingTransform = new GameObject().transform;
6569
}
6670

6771
/// <summary>
@@ -89,8 +93,9 @@ public void CollectObservationBodyPart(BodyPart bp)
8993
{
9094
var rb = bp.rb;
9195
AddVectorObs(bp.groundContact.touchingGround ? 1 : 0); // Whether the bp touching the ground
92-
AddVectorObs(rb.velocity);
93-
AddVectorObs(rb.angularVelocity);
96+
97+
AddVectorObs(workingTransform.InverseTransformVector(rb.velocity));
98+
AddVectorObs(workingTransform.InverseTransformDirection(rb.angularVelocity));
9499

95100
if (bp.rb.transform != body)
96101
{
@@ -107,12 +112,18 @@ public override void CollectObservations()
107112
{
108113
jdController.GetCurrentJointForces();
109114
// Normalize dir vector to help generalize
110-
AddVectorObs(dirToTarget.normalized);
111115

116+
workingTransform.rotation = Quaternion.LookRotation(dirToTarget);
112117
// Forward & up to help with orientation
113-
AddVectorObs(body.transform.position.y);
114-
AddVectorObs(body.forward);
115-
AddVectorObs(body.up);
118+
RaycastHit hit;
119+
if (Physics.Raycast(body.position, Vector3.down, out hit, 10.0f))
120+
{
121+
AddVectorObs(hit.distance);
122+
}
123+
else
124+
AddVectorObs(10.0f);
125+
AddVectorObs(workingTransform.InverseTransformVector(body.forward));
126+
AddVectorObs(workingTransform.InverseTransformVector(body.up));
116127
foreach (var bodyPart in jdController.bodyPartsDict.Values)
117128
{
118129
CollectObservationBodyPart(bodyPart);
@@ -257,6 +268,7 @@ public override void AgentReset()
257268
{
258269
transform.rotation = Quaternion.LookRotation(dirToTarget);
259270
}
271+
transform.Rotate(Vector3.up,Random.Range(0.0f, 360.0f));
260272

261273
foreach (var bodyPart in jdController.bodyPartsDict.Values)
262274
{

0 commit comments

Comments
 (0)