Skip to content

Commit d26a502

Browse files
mantaspErvin T
authored andcommitted
Barracuda hotfix for LSTM and tests (#2352)
* Removed obsolete 'TestDstWrongShape' test as it does not reflect how Barracuda tensors work * Added proper test cleanup, to avoid warning messages from finalizer thread. * Hotfix for recurrent + continous action nets in ML Agents
1 parent 7a2a922 commit d26a502

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

ml-agents/mlagents/trainers/tensorflow_to_barracuda.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@
5252
id=1,
5353
rank=2,
5454
out_shapes=lambda shapes: [
55-
[shapes[0][0], 1, 1, shapes[0][1]], # W
55+
[shapes[0][0], 1, 1, shapes[0][1]]
56+
if len(shapes[0]) > 1
57+
else [1, 1, 1, 1], # W
5658
[1, 1, 1, shapes[-1][-1]], # B
5759
],
5860
patch_data=lambda data: [data[0], data[1]],
@@ -324,9 +326,14 @@
324326
"ConcatV2",
325327
"Identity",
326328
]
327-
): "BasicLSTM",
328-
repr([re.compile("^lstm/"), "Reshape", "ConcatV2", "Identity"]): "BasicLSTM",
329-
repr(["Reshape", re.compile("^lstm_[a-z]*/"), "Reshape", "ConcatV2"]): "BasicLSTM",
329+
): "BasicLSTMReshapeOut",
330+
repr(
331+
[re.compile("^lstm/"), "Reshape", "ConcatV2", "Identity"]
332+
): "BasicLSTMReshapeOut",
333+
repr(
334+
["Reshape", re.compile("^lstm_[a-z]*/"), "Reshape", "ConcatV2"]
335+
): "BasicLSTMReshapeOut",
336+
repr(["Reshape", re.compile("^lstm_[a-z]*/"), "ConcatV2"]): "BasicLSTMConcatOut",
330337
repr(["Sigmoid", "Mul"]): "Swish",
331338
repr(["Mul", "Abs", "Mul", "Add"]): "LeakyRelu",
332339
repr(
@@ -546,9 +553,12 @@ def order_by(args, names):
546553
"SquaredDifference": lambda nodes, inputs, tensors, _: sqr_diff(
547554
nodes[-1].name, inputs[0], inputs[1]
548555
),
549-
"BasicLSTM": lambda nodes, inputs, tensors, context: basic_lstm(
556+
"BasicLSTMReshapeOut": lambda nodes, inputs, tensors, context: basic_lstm(
550557
nodes, inputs, tensors, context, find_type="Reshape"
551558
),
559+
"BasicLSTMConcatOut": lambda nodes, inputs, tensors, context: basic_lstm(
560+
nodes, inputs, tensors, context, find_type="ConcatV2"
561+
),
552562
"Swish": lambda nodes, inputs, tensors, _: Struct(op="Swish", input=inputs),
553563
"LeakyRelu": lambda nodes, inputs, tensors, _: Struct(op="LeakyRelu", input=inputs),
554564
# TODO:'Round'

0 commit comments

Comments
 (0)