Skip to content

Commit 7d4159a

Browse files
committed
update tests to latest unity models
1 parent 6566571 commit 7d4159a

File tree

3 files changed

+25
-29
lines changed

3 files changed

+25
-29
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ python -m tf2onnx.convert\
8080
--verbose
8181
```
8282
Some models specify placeholders with unknown ranks which can not be mapped to onnx.
83-
In those cases one can add the shape behind the input name in ```[]```, for example ```--input X:0[1,28,28,3]```
83+
In those cases one can add the shape behind the input name in ```[]```, for example ```--inputs X:0[1,28,28,3]```
8484

8585
## <a name="summarize_graph"></a>Tool to get Graph Inputs & Outputs
8686

tests/unity.yaml

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
model: 3DBall.bytes
55
input_get: get_random
66
inputs:
7-
"vector_observation:0": [1, 24]
7+
"vector_observation:0": [1, 8]
88
outputs:
99
- action:0
10+
- action_probs:0
11+
- value_estimate:0
1012

1113
3DBallHard:
1214
# needs: lstm
@@ -21,26 +23,6 @@
2123
outputs:
2224
- action:0
2325

24-
BananaIL:
25-
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/BananaCollectors/TFModels/BananaIL.bytes
26-
model: BananaIL.bytes
27-
input_get: get_random
28-
inputs:
29-
"vector_observation:0": [1, 159]
30-
outputs:
31-
- action:0
32-
33-
BananaRL:
34-
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/BananaCollectors/TFModels/BananaRL.bytes
35-
model: BananaRL.bytes
36-
check_only_shape: true
37-
input_get: get_random
38-
inputs:
39-
"vector_observation:0": [1, 159]
40-
outputs:
41-
- action_probs:0
42-
- value_estimate:0
43-
4426
Basic:
4527
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/Basic/TFModels/Basic.bytes
4628
model: Basic.bytes
@@ -63,16 +45,29 @@ Bouncer:
6345
- action_probs:0
6446
- value_estimate:0
6547

66-
crawler:
48+
DynamicCrawler:
6749
check_only_shape: true
68-
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/Crawler/TFModels/crawler.bytes
69-
model: crawler.bytes
50+
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/Crawler/TFModels/DynamicCrawler.bytes
51+
model: DynamicCrawler.bytes
7052
input_get: get_random
7153
inputs:
72-
"vector_observation:0": [1, 117]
54+
"vector_observation:0": [1, 129]
7355
outputs:
7456
- action_probs:0
7557
- value_estimate:0
58+
- action:0
59+
60+
FixedCrawler:
61+
check_only_shape: true
62+
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/Crawler/TFModels/FixedCrawler.bytes
63+
model: FixedCrawler.bytes
64+
input_get: get_random
65+
inputs:
66+
"vector_observation:0": [1, 129]
67+
outputs:
68+
- action_probs:0
69+
- value_estimate:0
70+
- action:0
7671

7772
GridWorld_3x3:
7873
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorld_3x3.bytes
@@ -131,10 +126,11 @@ Reacher:
131126
model: Reacher.bytes
132127
input_get: get_random
133128
inputs:
134-
"vector_observation:0": [1, 78]
129+
"vector_observation:0": [1, 33]
135130
outputs:
136-
- value_estimate:0
137131
- action:0
132+
- value_estimate:0
133+
- action_probs:0
138134

139135
Soccer:
140136
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/Soccer/TFModels/Soccer.bytes

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -911,7 +911,7 @@ def minmax_op(ctx, node, name, args):
911911
new_nodes = []
912912
for i in needs_broadcast_op:
913913
input_node = node.inputs[i]
914-
dtype = ctx.dtypes[node.input[i]]
914+
dtype = ctx.get_dtype(node.input[i])
915915
zero_name = utils.make_name(input_node.name)
916916
ctx.make_const(zero_name, "Const", np.zeros(shapeo, dtype=utils.ONNX_TO_NUMPY_DTYPE[dtype]))
917917
op_name = utils.make_name(input_node.name)

0 commit comments

Comments
 (0)