@@ -47,7 +47,8 @@ public class CrawlerAgent : Agent
4747 bool isNewDecisionStep ;
4848 int currentDecisionStep ;
4949
50- private Transform workingTransform ;
50+ Quaternion lookRotation ;
51+ Matrix4x4 targetDirMatrix ;
5152
5253 public override void InitializeAgent ( )
5354 {
@@ -64,9 +65,7 @@ public override void InitializeAgent()
6465 jdController . SetupBodyPart ( leg2Lower ) ;
6566 jdController . SetupBodyPart ( leg3Upper ) ;
6667 jdController . SetupBodyPart ( leg3Lower ) ;
67-
68- workingTransform = new GameObject ( ) . transform ;
69- }
68+ }
7069
7170 /// <summary>
7271 /// We only need to change the joint settings based on decision freq.
@@ -94,8 +93,11 @@ public void CollectObservationBodyPart(BodyPart bp)
9493 var rb = bp . rb ;
9594 AddVectorObs ( bp . groundContact . touchingGround ? 1 : 0 ) ; // Whether the bp touching the ground
9695
97- AddVectorObs ( workingTransform . InverseTransformVector ( rb . velocity ) ) ;
98- AddVectorObs ( workingTransform . InverseTransformDirection ( rb . angularVelocity ) ) ;
96+ Vector3 velocityRelativeToLookRotationToTarget = targetDirMatrix . inverse . MultiplyVector ( rb . velocity ) ;
97+ AddVectorObs ( velocityRelativeToLookRotationToTarget ) ;
98+
99+ Vector3 angularVelocityRelativeToLookRotationToTarget = targetDirMatrix . inverse . MultiplyVector ( rb . angularVelocity ) ;
100+ AddVectorObs ( angularVelocityRelativeToLookRotationToTarget ) ;
99101
100102 if ( bp . rb . transform != body )
101103 {
@@ -113,17 +115,24 @@ public override void CollectObservations()
113115 jdController . GetCurrentJointForces ( ) ;
114116 // Normalize dir vector to help generalize
115117
116- workingTransform . rotation = Quaternion . LookRotation ( dirToTarget ) ;
118+ lookRotation = Quaternion . LookRotation ( dirToTarget ) ;
119+ targetDirMatrix = Matrix4x4 . TRS ( Vector3 . zero , lookRotation , Vector3 . one ) ;
120+
117121 // Forward & up to help with orientation
118122 RaycastHit hit ;
119123 if ( Physics . Raycast ( body . position , Vector3 . down , out hit , 10.0f ) )
120124 {
121125 AddVectorObs ( hit . distance ) ;
122126 }
123127 else
124- AddVectorObs ( 10.0f ) ;
125- AddVectorObs ( workingTransform . InverseTransformVector ( body . forward ) ) ;
126- AddVectorObs ( workingTransform . InverseTransformVector ( body . up ) ) ;
128+ AddVectorObs ( 10.0f ) ;
129+
130+ Vector3 bodyForwardRelativeToLookRotationToTarget = targetDirMatrix . inverse . MultiplyVector ( body . forward ) ;
131+ AddVectorObs ( bodyForwardRelativeToLookRotationToTarget ) ;
132+
133+ Vector3 bodyUpRelativeToLookRotationToTarget = targetDirMatrix . inverse . MultiplyVector ( body . up ) ;
134+ AddVectorObs ( bodyUpRelativeToLookRotationToTarget ) ;
135+
127136 foreach ( var bodyPart in jdController . bodyPartsDict . Values )
128137 {
129138 CollectObservationBodyPart ( bodyPart ) ;
0 commit comments