Skip to content

Commit 9ce7a8e

Browse files
Ruo-Ping (Rachel) DongvincentpierreChris Elion
authored
Support multi-dimensional and compressed observations stacking (#4476)
Added stacking to multi-dimensional and compressed observations and added compressed channel mapping in communicator to support decompression. Co-authored-by: Vincent-Pierre BERGES <[email protected]> Co-authored-by: Chris Elion <[email protected]>
1 parent 3818a21 commit 9ce7a8e

28 files changed

+832
-97
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,19 @@ and this project adheres to
1313
- Added the Random Network Distillation (RND) intrinsic reward signal to the Pytorch
1414
trainers. To use RND, add a `rnd` section to the `reward_signals` section of your
1515
yaml configuration file. [More information here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-Configuration-File.md#rnd-intrinsic-reward)
16-
1716
### Minor Changes
1817
#### com.unity.ml-agents (C#)
18+
- Stacking for compressed observations is now supported. An addtional setting
19+
option `Observation Stacks` is added in editor to sensor components that support
20+
compressed observations. A new class `ISparseChannelSensor` with an
21+
additional method `GetCompressedChannelMapping()`is added to generate a mapping
22+
of the channels in compressed data to the actual channel after decompression,
23+
for the python side to decompress correctly. (#4476)
1924
#### ml-agents / ml-agents-envs / gym-unity (Python)
20-
25+
- The Communication API was changed to 1.2.0 to indicate support for stacked
26+
compressed observation. A new entry `compressed_channel_mapping` is added to the
27+
proto to handle decompression correctly. Newer versions of the package that wish to
28+
make use of this will also need a compatible version of the Python trainers. (#4476)
2129
### Bug Fixes
2230
#### com.unity.ml-agents (C#)
2331
#### ml-agents / ml-agents-envs / gym-unity (Python)

com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ public override void OnInspectorGUI()
2525
EditorGUILayout.PropertyField(so.FindProperty("m_Width"), true);
2626
EditorGUILayout.PropertyField(so.FindProperty("m_Height"), true);
2727
EditorGUILayout.PropertyField(so.FindProperty("m_Grayscale"), true);
28+
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true);
2829
}
2930
EditorGUI.EndDisabledGroup();
3031
EditorGUILayout.PropertyField(so.FindProperty("m_Compression"), true);

com.unity.ml-agents/Editor/RenderTextureSensorComponentEditor.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ public override void OnInspectorGUI()
2020
EditorGUILayout.PropertyField(so.FindProperty("m_RenderTexture"), true);
2121
EditorGUILayout.PropertyField(so.FindProperty("m_SensorName"), true);
2222
EditorGUILayout.PropertyField(so.FindProperty("m_Grayscale"), true);
23+
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true);
2324
}
2425
EditorGUI.EndDisabledGroup();
2526

com.unity.ml-agents/Runtime/Academy.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,13 @@ public class Academy : IDisposable
7474
/// <term>1.1.0</term>
7575
/// <description>Support concatenated PNGs for compressed observations.</description>
7676
/// </item>
77+
/// <item>
78+
/// <term>1.2.0</term>
79+
/// <description>Support compression mapping for stacked compressed observations.</description>
80+
/// </item>
7781
/// </list>
7882
/// </remarks>
79-
const string k_ApiVersion = "1.1.0";
83+
const string k_ApiVersion = "1.2.0";
8084

8185
/// <summary>
8286
/// Unity package version of com.unity.ml-agents.

com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,8 @@ public static List<float[]> ToAgentActionList(this UnityRLInputProto.Types.ListA
222222
/// <summary>
223223
/// Static flag to make sure that we only fire the warning once.
224224
/// </summary>
225-
private static bool s_HaveWarnedAboutTrainerCapabilities = false;
225+
private static bool s_HaveWarnedTrainerCapabilitiesMultiPng = false;
226+
private static bool s_HaveWarnedTrainerCapabilitiesMapping = false;
226227

227228
/// <summary>
228229
/// Generate an ObservationProto for the sensor using the provided ObservationWriter.
@@ -243,10 +244,27 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat
243244
var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.ConcatenatedPngObservations;
244245
if (!trainerCanHandle)
245246
{
246-
if (!s_HaveWarnedAboutTrainerCapabilities)
247+
if (!s_HaveWarnedTrainerCapabilitiesMultiPng)
247248
{
248249
Debug.LogWarning($"Attached trainer doesn't support multiple PNGs. Switching to uncompressed observations for sensor {sensor.GetName()}.");
249-
s_HaveWarnedAboutTrainerCapabilities = true;
250+
s_HaveWarnedTrainerCapabilitiesMultiPng = true;
251+
}
252+
compressionType = SensorCompressionType.None;
253+
}
254+
}
255+
// Check capabilities if we need mapping for compressed observations
256+
if (compressionType != SensorCompressionType.None && shape.Length == 3 && shape[2] > 3)
257+
{
258+
var trainerCanHandleMapping = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.CompressedChannelMapping;
259+
var isTrivialMapping = IsTrivialMapping(sensor);
260+
if (!trainerCanHandleMapping && !isTrivialMapping)
261+
{
262+
if (!s_HaveWarnedTrainerCapabilitiesMapping)
263+
{
264+
Debug.LogWarning($"The sensor {sensor.GetName()} is using non-trivial mapping and " +
265+
"the attached trainer doesn't support compression mapping. " +
266+
"Switching to uncompressed observations.");
267+
s_HaveWarnedTrainerCapabilitiesMapping = true;
250268
}
251269
compressionType = SensorCompressionType.None;
252270
}
@@ -283,12 +301,16 @@ public static ObservationProto GetObservationProto(this ISensor sensor, Observat
283301
"return SensorCompressionType.None from GetCompressionType()."
284302
);
285303
}
286-
287304
observationProto = new ObservationProto
288305
{
289306
CompressedData = ByteString.CopyFrom(compressedObs),
290307
CompressionType = (CompressionTypeProto)sensor.GetCompressionType(),
291308
};
309+
var compressibleSensor = sensor as ISparseChannelSensor;
310+
if (compressibleSensor != null)
311+
{
312+
observationProto.CompressedChannelMapping.AddRange(compressibleSensor.GetCompressedChannelMapping());
313+
}
292314
}
293315
observationProto.Shape.AddRange(shape);
294316
return observationProto;
@@ -300,7 +322,8 @@ public static UnityRLCapabilities ToRLCapabilities(this UnityRLCapabilitiesProto
300322
return new UnityRLCapabilities
301323
{
302324
BaseRLCapabilities = proto.BaseRLCapabilities,
303-
ConcatenatedPngObservations = proto.ConcatenatedPngObservations
325+
ConcatenatedPngObservations = proto.ConcatenatedPngObservations,
326+
CompressedChannelMapping = proto.CompressedChannelMapping,
304327
};
305328
}
306329

@@ -310,7 +333,36 @@ public static UnityRLCapabilitiesProto ToProto(this UnityRLCapabilities rlCaps)
310333
{
311334
BaseRLCapabilities = rlCaps.BaseRLCapabilities,
312335
ConcatenatedPngObservations = rlCaps.ConcatenatedPngObservations,
336+
CompressedChannelMapping = rlCaps.CompressedChannelMapping,
313337
};
314338
}
339+
340+
internal static bool IsTrivialMapping(ISensor sensor)
341+
{
342+
var compressibleSensor = sensor as ISparseChannelSensor;
343+
if (compressibleSensor is null)
344+
{
345+
return true;
346+
}
347+
var mapping = compressibleSensor.GetCompressedChannelMapping();
348+
if (mapping == null)
349+
{
350+
return true;
351+
}
352+
// check if mapping equals zero mapping
353+
if (mapping.Length == 3 && mapping.All(m => m == 0))
354+
{
355+
return true;
356+
}
357+
// check if mapping equals identity mapping
358+
for (var i = 0; i < mapping.Length; i++)
359+
{
360+
if (mapping[i] != i)
361+
{
362+
return false;
363+
}
364+
}
365+
return true;
366+
}
315367
}
316368
}

com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,17 @@ internal class UnityRLCapabilities
66
{
77
public bool BaseRLCapabilities;
88
public bool ConcatenatedPngObservations;
9+
public bool CompressedChannelMapping;
910

1011
/// <summary>
1112
/// A class holding the capabilities flags for Reinforcement Learning across C# and the Trainer codebase. This
1213
/// struct will be used to inform users if and when they are using C# / Trainer features that are mismatched.
1314
/// </summary>
14-
public UnityRLCapabilities(bool baseRlCapabilities = true, bool concatenatedPngObservations = true)
15+
public UnityRLCapabilities(bool baseRlCapabilities = true, bool concatenatedPngObservations = true, bool compressedChannelMapping = true)
1516
{
1617
BaseRLCapabilities = baseRlCapabilities;
1718
ConcatenatedPngObservations = concatenatedPngObservations;
19+
CompressedChannelMapping = compressedChannelMapping;
1820
}
1921

2022
/// <summary>

com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,15 @@ static CapabilitiesReflection() {
2525
byte[] descriptorData = global::System.Convert.FromBase64String(
2626
string.Concat(
2727
"CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp",
28-
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMiWwoYVW5pdHlSTENh",
28+
"dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMifQoYVW5pdHlSTENh",
2929
"cGFiaWxpdGllc1Byb3RvEhoKEmJhc2VSTENhcGFiaWxpdGllcxgBIAEoCBIj",
30-
"Chtjb25jYXRlbmF0ZWRQbmdPYnNlcnZhdGlvbnMYAiABKAhCJaoCIlVuaXR5",
31-
"Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
30+
"Chtjb25jYXRlbmF0ZWRQbmdPYnNlcnZhdGlvbnMYAiABKAgSIAoYY29tcHJl",
31+
"c3NlZENoYW5uZWxNYXBwaW5nGAMgASgIQiWqAiJVbml0eS5NTEFnZW50cy5D",
32+
"b21tdW5pY2F0b3JPYmplY3RzYgZwcm90bzM="));
3233
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
3334
new pbr::FileDescriptor[] { },
3435
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
35-
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations" }, null, null, null)
36+
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping" }, null, null, null)
3637
}));
3738
}
3839
#endregion
@@ -71,6 +72,7 @@ public UnityRLCapabilitiesProto() {
7172
public UnityRLCapabilitiesProto(UnityRLCapabilitiesProto other) : this() {
7273
baseRLCapabilities_ = other.baseRLCapabilities_;
7374
concatenatedPngObservations_ = other.concatenatedPngObservations_;
75+
compressedChannelMapping_ = other.compressedChannelMapping_;
7476
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
7577
}
7678

@@ -107,6 +109,20 @@ public bool ConcatenatedPngObservations {
107109
}
108110
}
109111

112+
/// <summary>Field number for the "compressedChannelMapping" field.</summary>
113+
public const int CompressedChannelMappingFieldNumber = 3;
114+
private bool compressedChannelMapping_;
115+
/// <summary>
116+
/// compression mapping for stacking compressed observations.
117+
/// </summary>
118+
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
119+
public bool CompressedChannelMapping {
120+
get { return compressedChannelMapping_; }
121+
set {
122+
compressedChannelMapping_ = value;
123+
}
124+
}
125+
110126
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
111127
public override bool Equals(object other) {
112128
return Equals(other as UnityRLCapabilitiesProto);
@@ -122,6 +138,7 @@ public bool Equals(UnityRLCapabilitiesProto other) {
122138
}
123139
if (BaseRLCapabilities != other.BaseRLCapabilities) return false;
124140
if (ConcatenatedPngObservations != other.ConcatenatedPngObservations) return false;
141+
if (CompressedChannelMapping != other.CompressedChannelMapping) return false;
125142
return Equals(_unknownFields, other._unknownFields);
126143
}
127144

@@ -130,6 +147,7 @@ public override int GetHashCode() {
130147
int hash = 1;
131148
if (BaseRLCapabilities != false) hash ^= BaseRLCapabilities.GetHashCode();
132149
if (ConcatenatedPngObservations != false) hash ^= ConcatenatedPngObservations.GetHashCode();
150+
if (CompressedChannelMapping != false) hash ^= CompressedChannelMapping.GetHashCode();
133151
if (_unknownFields != null) {
134152
hash ^= _unknownFields.GetHashCode();
135153
}
@@ -151,6 +169,10 @@ public void WriteTo(pb::CodedOutputStream output) {
151169
output.WriteRawTag(16);
152170
output.WriteBool(ConcatenatedPngObservations);
153171
}
172+
if (CompressedChannelMapping != false) {
173+
output.WriteRawTag(24);
174+
output.WriteBool(CompressedChannelMapping);
175+
}
154176
if (_unknownFields != null) {
155177
_unknownFields.WriteTo(output);
156178
}
@@ -165,6 +187,9 @@ public int CalculateSize() {
165187
if (ConcatenatedPngObservations != false) {
166188
size += 1 + 1;
167189
}
190+
if (CompressedChannelMapping != false) {
191+
size += 1 + 1;
192+
}
168193
if (_unknownFields != null) {
169194
size += _unknownFields.CalculateSize();
170195
}
@@ -182,6 +207,9 @@ public void MergeFrom(UnityRLCapabilitiesProto other) {
182207
if (other.ConcatenatedPngObservations != false) {
183208
ConcatenatedPngObservations = other.ConcatenatedPngObservations;
184209
}
210+
if (other.CompressedChannelMapping != false) {
211+
CompressedChannelMapping = other.CompressedChannelMapping;
212+
}
185213
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
186214
}
187215

@@ -201,6 +229,10 @@ public void MergeFrom(pb::CodedInputStream input) {
201229
ConcatenatedPngObservations = input.ReadBool();
202230
break;
203231
}
232+
case 24: {
233+
CompressedChannelMapping = input.ReadBool();
234+
break;
235+
}
204236
}
205237
}
206238
}

0 commit comments

Comments
 (0)