Skip to content

Commit 18f6616

Browse files
Uses matrix maths for observation vectors.
1 parent 74d82a5 commit 18f6616

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

UnitySDK/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ public class CrawlerAgent : Agent
4747
bool isNewDecisionStep;
4848
int currentDecisionStep;
4949

50-
private Transform workingTransform;
50+
Quaternion lookRotation;
51+
Matrix4x4 targetDirMatrix;
5152

5253
public override void InitializeAgent()
5354
{
@@ -64,9 +65,7 @@ public override void InitializeAgent()
6465
jdController.SetupBodyPart(leg2Lower);
6566
jdController.SetupBodyPart(leg3Upper);
6667
jdController.SetupBodyPart(leg3Lower);
67-
68-
workingTransform = new GameObject().transform;
69-
}
68+
}
7069

7170
/// <summary>
7271
/// We only need to change the joint settings based on decision freq.
@@ -94,8 +93,11 @@ public void CollectObservationBodyPart(BodyPart bp)
9493
var rb = bp.rb;
9594
AddVectorObs(bp.groundContact.touchingGround ? 1 : 0); // Whether the bp touching the ground
9695

97-
AddVectorObs(workingTransform.InverseTransformVector(rb.velocity));
98-
AddVectorObs(workingTransform.InverseTransformDirection(rb.angularVelocity));
96+
Vector3 velocityRelativeToLookRotationToTarget = targetDirMatrix.inverse.MultiplyVector(rb.velocity);
97+
AddVectorObs(velocityRelativeToLookRotationToTarget);
98+
99+
Vector3 angularVelocityRelativeToLookRotationToTarget = targetDirMatrix.inverse.MultiplyVector(rb.angularVelocity);
100+
AddVectorObs(angularVelocityRelativeToLookRotationToTarget);
99101

100102
if (bp.rb.transform != body)
101103
{
@@ -113,17 +115,24 @@ public override void CollectObservations()
113115
jdController.GetCurrentJointForces();
114116
// Normalize dir vector to help generalize
115117

116-
workingTransform.rotation = Quaternion.LookRotation(dirToTarget);
118+
lookRotation = Quaternion.LookRotation(dirToTarget);
119+
targetDirMatrix = Matrix4x4.TRS(Vector3.zero, lookRotation, Vector3.one);
120+
117121
// Forward & up to help with orientation
118122
RaycastHit hit;
119123
if (Physics.Raycast(body.position, Vector3.down, out hit, 10.0f))
120124
{
121125
AddVectorObs(hit.distance);
122126
}
123127
else
124-
AddVectorObs(10.0f);
125-
AddVectorObs(workingTransform.InverseTransformVector(body.forward));
126-
AddVectorObs(workingTransform.InverseTransformVector(body.up));
128+
AddVectorObs(10.0f);
129+
130+
Vector3 bodyForwardRelativeToLookRotationToTarget = targetDirMatrix.inverse.MultiplyVector(body.forward);
131+
AddVectorObs(bodyForwardRelativeToLookRotationToTarget);
132+
133+
Vector3 bodyUpRelativeToLookRotationToTarget = targetDirMatrix.inverse.MultiplyVector(body.up);
134+
AddVectorObs(bodyUpRelativeToLookRotationToTarget);
135+
127136
foreach (var bodyPart in jdController.bodyPartsDict.Values)
128137
{
129138
CollectObservationBodyPart(bodyPart);

0 commit comments

Comments
 (0)