Skip to content

Commit e8e3ab1

Browse files
xaduprevinitra
authored andcommitted
include onnxutils tests in CI, add a unit test on DataTypes, update optimizer (#283)
* add a unit test on DataTypes * update CI with onnxutils pytest (for onnxconverter-common) * list() instead of list.copy() for python 2.7 compatibility in optimizer
1 parent d5bd3cd commit e8e3ab1

File tree

5 files changed

+39
-4
lines changed

5 files changed

+39
-4
lines changed

.azure-pipelines/linux-conda-CI.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,13 @@ jobs:
5555
python -c "import onnxconverter_common"
5656
test '$(python.version)' != '2.7' && python -c "import onnxruntime"
5757
pytest tests --doctest-modules --junitxml=junit/test-results.xml
58-
displayName: 'pytest'
58+
displayName: 'pytest - onnxmltools'
59+
60+
- script: |
61+
export PYTHONPATH=$PYTHONPATH:libsvm/python
62+
python -c "import onnxconverter_common"
63+
pytest onnxutils/tests --doctest-modules --junitxml=junit/test-results-onnxutils.xml
64+
displayName: 'pytest - onnxutils'
5965
6066
- task: PublishTestResults@2
6167
inputs:

.azure-pipelines/win32-conda-CI.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,13 @@ jobs:
6767
set PYTHONPATH=libsvm\python;%PYTHONPATH%
6868
pip install -e .
6969
pytest tests --doctest-modules --junitxml=junit/test-results.xml
70-
displayName: 'pytest'
70+
displayName: 'pytest - onnxmltools'
71+
72+
- script: |
73+
call activate py$(python.version)
74+
set PYTHONPATH=libsvm\python;%PYTHONPATH%
75+
pytest onnxutils/tests --doctest-modules --junitxml=junit/test-results-onnxutils.xml
76+
displayName: 'pytest - onnxutils'
7177
7278
- task: PublishTestResults@2
7379
inputs:

onnxutils/onnxconverter_common/data_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import onnx
99
from onnx import onnx_pb as onnx_proto
1010

11+
1112
class DataType(object):
1213
def __init__(self, shape=None, doc_string=''):
1314
self.shape = shape

onnxutils/onnxconverter_common/optimizer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,8 @@ class FanOutSolution(Solution):
381381
number = 0
382382
def apply(self, node_list):
383383
cur_perm = Solution.get_perm(self.begin_n.origin)
384-
successor_list = self.end_p.successor.copy()
384+
# make a copy of self.end_p.successor
385+
successor_list = list(self.end_p.successor)
385386

386387
for suc in successor_list:
387388
nnode = LinkedNode(
@@ -413,7 +414,8 @@ def apply(self, node_list):
413414
perm=self.perm,
414415
name='TransposeFanIn' + str(FanInSolution.number)))
415416
FanInSolution.number = FanInSolution.number + 1
416-
precedence_list = self.begin.precedence.copy()
417+
# make a copy of self.begin.precedence
418+
precedence_list = list(self.begin.precedence)
417419
node_list = Solution.add_siso_node(node_list, self.begin, self.begin_n, list(self.begin.output.values())[0], nnode)
418420
for branch in precedence_list:
419421
node_list = Solution.delete_node_1ton(node_list, branch.precedence[0], branch, self.begin)

onnxutils/tests/test_onnx.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import unittest
2+
from onnxconverter_common.data_types import FloatTensorType
3+
4+
5+
class TestTypes(unittest.TestCase):
6+
7+
def test_to_onnx_type(self):
8+
dt = FloatTensorType((1, 5))
9+
assert str(dt) == 'FloatTensorType(shape=(1, 5))'
10+
onx = dt.to_onnx_type()
11+
assert "dim_value: 5" in str(onx)
12+
tt = onx.tensor_type
13+
assert "dim_value: 5" in str(tt)
14+
assert tt.elem_type == 1
15+
o = onx.sequence_type
16+
assert str(o) == ""
17+
18+
19+
if __name__ == '__main__':
20+
unittest.main()

0 commit comments

Comments
 (0)