|
36 | 36 | _INPUT1 = "input1:0"
|
37 | 37 | _TFINPUT2 = "input2"
|
38 | 38 | _INPUT2 = "input2:0"
|
| 39 | +_TFINPUT3 = "input3" |
| 40 | +_INPUT3 = "input3:0" |
39 | 41 | _TFOUTPUT = "output"
|
40 | 42 | _OUTPUT = "output:0"
|
41 | 43 | _TFOUTPUT1 = "output1"
|
@@ -2470,6 +2472,74 @@ def test_batch_to_spacend(self):
|
2470 | 2472 | _ = tf.batch_to_space_nd(input_x, block_size, crop, name=_TFOUTPUT)
|
2471 | 2473 | self._run_test_case([_OUTPUT], {_INPUT: input_val})
|
2472 | 2474 |
|
| 2475 | + @check_opset_min_version(11, "BatchToSpaceND") |
| 2476 | + def test_batch_to_spacend_non_const(self): |
| 2477 | + input_x_val = np.random.random_sample([40, 3, 5, 100]).astype(np.float32) # NHWC |
| 2478 | + block_shape_val = np.array([2, 2]).astype(np.int64) |
| 2479 | + crops_val = np.array([[1, 0], [2, 1]]).astype(np.int64) |
| 2480 | + input_x = tf.placeholder(dtype=tf.float32, shape=input_x_val.shape, name=_TFINPUT) |
| 2481 | + block_shape = tf.placeholder(dtype=tf.int64, shape=block_shape_val.shape, name=_TFINPUT1) |
| 2482 | + crops = tf.placeholder(dtype=tf.int64, shape=crops_val.shape, name=_TFINPUT2) |
| 2483 | + _ = tf.batch_to_space_nd(input_x, block_shape, crops, name=_TFOUTPUT) |
| 2484 | + self._run_test_case([_OUTPUT], {_INPUT: input_x_val, _INPUT1: block_shape_val, _INPUT2: crops_val}) |
| 2485 | + |
| 2486 | + @check_opset_min_version(11, "SpaceToBatchND") |
| 2487 | + def test_space_to_batchnd_non_const(self): |
| 2488 | + input_x_val = np.random.random_sample([40, 5, 7, 66]).astype(np.float32) # NHWC |
| 2489 | + block_size_val = np.array([2, 2]).astype(np.int64) |
| 2490 | + pad_val = np.array([[0, 1], [2, 1]]).astype(np.int64) |
| 2491 | + input_x = tf.placeholder(dtype=tf.float32, shape=input_x_val.shape, name=_TFINPUT) |
| 2492 | + block_size = tf.placeholder(dtype=tf.int64, shape=block_size_val.shape, name=_TFINPUT1) |
| 2493 | + pad = tf.placeholder(dtype=tf.int64, shape=pad_val.shape, name=_TFINPUT2) |
| 2494 | + _ = tf.space_to_batch_nd(input_x, block_size, pad, name=_TFOUTPUT) |
| 2495 | + self._run_test_case([_OUTPUT], {_INPUT: input_x_val, _INPUT1: block_size_val, _INPUT2: pad_val}) |
| 2496 | + |
| 2497 | + @check_opset_min_version(11, "CropAndResize") |
| 2498 | + def test_crop_and_resize_linear(self): |
| 2499 | + input_x_val = np.random.randint(low=0, high=256, size=[2, 36, 36, 3]).astype(np.float32) # NHWC |
| 2500 | + boxes_val = np.array([[0.5, 0.7, 0.7, 0.9], [0.2, 0.4, 0.4, 0.6]]).astype(np.float32) |
| 2501 | + box_ind_val = np.array([1, 0]).astype(np.int32) |
| 2502 | + corp_size_val = np.array([20, 20]).astype(np.int32) |
| 2503 | + input_x = tf.placeholder(dtype=tf.float32, shape=input_x_val.shape, name=_TFINPUT) |
| 2504 | + boxes = tf.placeholder(dtype=tf.float32, shape=boxes_val.shape, name=_TFINPUT1) |
| 2505 | + box_ind = tf.placeholder(dtype=tf.int32, shape=box_ind_val.shape, name=_TFINPUT2) |
| 2506 | + corp_size = tf.placeholder(dtype=tf.int32, shape=corp_size_val.shape, name=_TFINPUT3) |
| 2507 | + _ = tf.image.crop_and_resize(input_x, boxes, box_ind, corp_size, name=_TFOUTPUT, method='bilinear') |
| 2508 | + self._run_test_case([_OUTPUT], |
| 2509 | + {_INPUT: input_x_val, _INPUT1: boxes_val, _INPUT2: box_ind_val, _INPUT3: corp_size_val}, |
| 2510 | + rtol=1e-05, atol=1e-04) |
| 2511 | + |
| 2512 | + @check_tf_min_version("1.9") |
| 2513 | + @check_opset_min_version(11, "CropAndResize") |
| 2514 | + def test_crop_and_resize_nearest(self): |
| 2515 | + input_x_val = np.random.randint(low=0, high=256, size=[1, 36, 36, 3]).astype(np.float32) # NHWC |
| 2516 | + boxes_val = np.array([[0.2, 0.4, 0.6, 0.8]]).astype(np.float32) |
| 2517 | + box_ind_val = np.array([0]).astype(np.int32) |
| 2518 | + corp_size_val = np.array([30, 30]).astype(np.int32) |
| 2519 | + input_x = tf.placeholder(dtype=tf.float32, shape=input_x_val.shape, name=_TFINPUT) |
| 2520 | + boxes = tf.placeholder(dtype=tf.float32, shape=boxes_val.shape, name=_TFINPUT1) |
| 2521 | + box_ind = tf.placeholder(dtype=tf.int32, shape=box_ind_val.shape, name=_TFINPUT2) |
| 2522 | + corp_size = tf.placeholder(dtype=tf.int32, shape=corp_size_val.shape, name=_TFINPUT3) |
| 2523 | + _ = tf.image.crop_and_resize(input_x, boxes, box_ind, corp_size, name=_TFOUTPUT, method='nearest') |
| 2524 | + self._run_test_case([_OUTPUT], |
| 2525 | + {_INPUT: input_x_val, _INPUT1: boxes_val, _INPUT2: box_ind_val, _INPUT3: corp_size_val}, |
| 2526 | + rtol=1e-05, atol=1e-04) |
| 2527 | + |
| 2528 | + @check_opset_min_version(11, "CropAndResize") |
| 2529 | + def test_crop_and_resize_extrapolation(self): |
| 2530 | + input_x_val = np.random.randint(low=0, high=256, size=[1, 36, 36, 3]).astype(np.float32) # NHWC |
| 2531 | + boxes_val = np.array([[0.2, 0.4, 1.2, 1.4]]).astype(np.float32) |
| 2532 | + box_ind_val = np.array([0]).astype(np.int32) |
| 2533 | + corp_size_val = np.array([40, 40]).astype(np.int32) |
| 2534 | + input_x = tf.placeholder(dtype=tf.float32, shape=input_x_val.shape, name=_TFINPUT) |
| 2535 | + boxes = tf.placeholder(dtype=tf.float32, shape=boxes_val.shape, name=_TFINPUT1) |
| 2536 | + box_ind = tf.placeholder(dtype=tf.int32, shape=box_ind_val.shape, name=_TFINPUT2) |
| 2537 | + corp_size = tf.placeholder(dtype=tf.int32, shape=corp_size_val.shape, name=_TFINPUT3) |
| 2538 | + _ = tf.image.crop_and_resize(input_x, boxes, box_ind, corp_size, name=_TFOUTPUT, extrapolation_value=1.0) |
| 2539 | + self._run_test_case([_OUTPUT], |
| 2540 | + {_INPUT: input_x_val, _INPUT1: boxes_val, _INPUT2: box_ind_val, _INPUT3: corp_size_val}, |
| 2541 | + rtol=1e-04, atol=1e-03) |
| 2542 | + |
2473 | 2543 | def test_batch_to_space3d(self):
|
2474 | 2544 | block_size = [2, 2]
|
2475 | 2545 | crop = [[0, 1], [2, 1]]
|
@@ -2828,9 +2898,8 @@ def test_unique(self):
|
2828 | 2898 | _ = tf.identity(x1_, name=_TFOUTPUT)
|
2829 | 2899 | _ = tf.identity(x2_, name=_TFOUTPUT1)
|
2830 | 2900 | # FIXME: indices in onnx are not the same as in tensorflow so don't check for now
|
2831 |
| - #self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: x_val}) |
| 2901 | + # self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: x_val}) |
2832 | 2902 | self._run_test_case([_OUTPUT], {_INPUT: x_val})
|
2833 | 2903 |
|
2834 |
| - |
2835 | 2904 | if __name__ == '__main__':
|
2836 | 2905 | unittest_main()
|
0 commit comments