Skip to content

Commit 71f4d68

Browse files
committed
small changes for test cases
1 parent 8b5fba2 commit 71f4d68

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

tests/test_optimizers.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,12 @@ def test_transpose_max(self):
184184
const_1 = helper.make_tensor("const_1", TensorProto.FLOAT, (1,), const_1_val)
185185
const_1_node = helper.make_node("Constant", [], ["const_1"], value=const_1, name="const_1")
186186

187-
const_2_val = np.random.randn(2, 4, 5, 3).astype(np.float32).reshape(120).tolist()
188-
const_2 = helper.make_tensor("const_2", TensorProto.FLOAT, (2, 4, 5, 3), const_2_val)
187+
const_2_val = np.random.randn(2, 4, 5, 3).astype(np.float32)
188+
const_2 = helper.make_tensor("const_2", TensorProto.FLOAT, (2, 4, 5, 3), const_2_val.flatten())
189189
const_2_node = helper.make_node("Constant", [], ["const_2"], value=const_2, name="const_2")
190190

191-
const_3_val = np.random.randn(2, 4, 5, 3).astype(np.float32).reshape(120).tolist()
192-
const_3 = helper.make_tensor("const_3", TensorProto.FLOAT, (2, 4, 5, 3), const_3_val)
191+
const_3_val = np.random.randn(2, 4, 5, 3).astype(np.float32)
192+
const_3 = helper.make_tensor("const_3", TensorProto.FLOAT, (2, 4, 5, 3), const_3_val.flatten())
193193
const_3_node = helper.make_node("Constant", [], ["const_3"], value=const_3, name="const_3")
194194

195195
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
@@ -212,8 +212,8 @@ def test_transpose_max_input_non_const(self):
212212
const_1 = helper.make_tensor("const_1", TensorProto.FLOAT, (1,), const_1_val)
213213
const_1_node = helper.make_node("Constant", [], ["const_1"], value=const_1, name="const_1")
214214

215-
const_2_val = np.random.randn(2, 4, 5, 3).astype(np.float32).reshape(120).tolist()
216-
const_2 = helper.make_tensor("const_2", TensorProto.FLOAT, (2, 4, 5, 3), const_2_val)
215+
const_2_val = np.random.randn(2, 4, 5, 3).astype(np.float32)
216+
const_2 = helper.make_tensor("const_2", TensorProto.FLOAT, (2, 4, 5, 3), const_2_val.flatten())
217217
const_2_node = helper.make_node("Constant", [], ["const_2"], value=const_2, name="const_2")
218218

219219
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
@@ -465,8 +465,8 @@ def test_transpose_add_with_input_non_const(self):
465465
model_proto, remaining_transpose_num=0)
466466

467467
def test_transpose_add_with_input_const(self):
468-
const_1_val = np.random.randn(1, 3, 3, 1).astype(np.float32).reshape(9).tolist()
469-
const_1 = helper.make_tensor("const_1", TensorProto.FLOAT, (1, 3, 3, 1), const_1_val)
468+
const_1_val = np.random.randn(1, 3, 3, 1).astype(np.float32)
469+
const_1 = helper.make_tensor("const_1", TensorProto.FLOAT, (1, 3, 3, 1), const_1_val.flatten())
470470
const_1_node = helper.make_node("Constant", [], ["const_1"], value=const_1, name="const_1")
471471

472472
node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
@@ -485,8 +485,9 @@ def test_transpose_add_with_input_const(self):
485485
model_proto, remaining_transpose_num=0)
486486

487487
def test_transpose_add_with_conv_1(self):
488-
const_b_val = np.random.randn(1, 1, 1, 16).astype(np.float32).reshape(16).tolist()
489-
const_b = helper.make_tensor("const_b", TensorProto.FLOAT, (1, 1, 1, 16), const_b_val)
488+
# case where bias's dim is 1D and can be merged into Conv
489+
const_b_val = np.random.randn(1, 1, 1, 16).astype(np.float32)
490+
const_b = helper.make_tensor("const_b", TensorProto.FLOAT, (1, 1, 1, 16), const_b_val.flatten())
490491
const_b_node = helper.make_node("Constant", [], ["const_b"], value=const_b, name="const_b")
491492

492493
node0 = helper.make_node("Conv", ["x", "W"], ["X"], name="conv", pads=[0, 0, 0, 0])
@@ -508,8 +509,10 @@ def test_transpose_add_with_conv_1(self):
508509
model_proto, remaining_transpose_num=0)
509510

510511
def test_transpose_add_with_conv_2(self):
511-
const_b_val = np.random.randn(1, 3, 3, 1).astype(np.float32).reshape(9).tolist()
512-
const_b = helper.make_tensor("const_b", TensorProto.FLOAT, (1, 3, 3, 1), const_b_val)
512+
# case where bias's dim is not 1D and can't be merged into Conv
513+
# add handler just remove the transpose around Add node
514+
const_b_val = np.random.randn(1, 3, 3, 1).astype(np.float32)
515+
const_b = helper.make_tensor("const_b", TensorProto.FLOAT, (1, 3, 3, 1), const_b_val.flatten())
513516
const_b_node = helper.make_node("Constant", [], ["const_b"], value=const_b, name="const_b")
514517

515518
node0 = helper.make_node("Conv", ["x", "W"], ["X"], name="conv", pads=[0, 0, 0, 0])

0 commit comments

Comments
 (0)