|
18 | 18 |
|
19 | 19 | from monai.networks import eval_mode |
20 | 20 | from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, UnetUpBlock, get_padding |
21 | | -from tests.test_utils import test_script_save |
| 21 | +from tests.test_utils import dict_product, test_script_save |
22 | 22 |
|
23 | 23 | TEST_CASE_RES_BASIC_BLOCK = [] |
24 | | -for spatial_dims in range(2, 4): |
25 | | - for kernel_size in [1, 3]: |
26 | | - for stride in [1, 2]: |
27 | | - for norm_name in [("GROUP", {"num_groups": 16}), ("batch", {"track_running_stats": False}), "instance"]: |
28 | | - for in_size in [15, 16]: |
29 | | - padding = get_padding(kernel_size, stride) |
30 | | - if not isinstance(padding, int): |
31 | | - padding = padding[0] |
32 | | - out_size = int((in_size + 2 * padding - kernel_size) / stride) + 1 |
33 | | - test_case = [ |
34 | | - { |
35 | | - "spatial_dims": spatial_dims, |
36 | | - "in_channels": 16, |
37 | | - "out_channels": 16, |
38 | | - "kernel_size": kernel_size, |
39 | | - "norm_name": norm_name, |
40 | | - "act_name": ("leakyrelu", {"inplace": True, "negative_slope": 0.1}), |
41 | | - "stride": stride, |
42 | | - }, |
43 | | - (1, 16, *([in_size] * spatial_dims)), |
44 | | - (1, 16, *([out_size] * spatial_dims)), |
45 | | - ] |
46 | | - TEST_CASE_RES_BASIC_BLOCK.append(test_case) |
| 24 | +for params in dict_product( |
| 25 | + spatial_dims=range(2, 4), |
| 26 | + kernel_size=[1, 3], |
| 27 | + stride=[1, 2], |
| 28 | + norm_name=[("GROUP", {"num_groups": 16}), ("batch", {"track_running_stats": False}), "instance"], |
| 29 | + in_size=[15, 16], |
| 30 | +): |
| 31 | + padding = get_padding(params["kernel_size"], params["stride"]) |
| 32 | + if not isinstance(padding, int): |
| 33 | + padding = padding[0] |
| 34 | + out_size = int((params["in_size"] + 2 * padding - params["kernel_size"]) / params["stride"]) + 1 |
| 35 | + test_case = [ |
| 36 | + { |
| 37 | + **{k: v for k, v in params.items() if k != "in_size"}, |
| 38 | + "in_channels": 16, |
| 39 | + "out_channels": 16, |
| 40 | + "act_name": ("leakyrelu", {"inplace": True, "negative_slope": 0.1}), |
| 41 | + }, |
| 42 | + (1, 16, *([params["in_size"]] * params["spatial_dims"])), |
| 43 | + (1, 16, *([out_size] * params["spatial_dims"])), |
| 44 | + ] |
| 45 | + TEST_CASE_RES_BASIC_BLOCK.append(test_case) |
47 | 46 |
|
48 | 47 | TEST_UP_BLOCK = [] |
49 | 48 | in_channels, out_channels = 4, 2 |
50 | | -for spatial_dims in range(2, 4): |
51 | | - for kernel_size in [1, 3]: |
52 | | - for stride in [1, 2]: |
53 | | - for norm_name in ["batch", "instance"]: |
54 | | - for in_size in [15, 16]: |
55 | | - for trans_bias in [True, False]: |
56 | | - out_size = in_size * stride |
57 | | - test_case = [ |
58 | | - { |
59 | | - "spatial_dims": spatial_dims, |
60 | | - "in_channels": in_channels, |
61 | | - "out_channels": out_channels, |
62 | | - "kernel_size": kernel_size, |
63 | | - "norm_name": norm_name, |
64 | | - "stride": stride, |
65 | | - "upsample_kernel_size": stride, |
66 | | - "trans_bias": trans_bias, |
67 | | - }, |
68 | | - (1, in_channels, *([in_size] * spatial_dims)), |
69 | | - (1, out_channels, *([out_size] * spatial_dims)), |
70 | | - (1, out_channels, *([in_size * stride] * spatial_dims)), |
71 | | - ] |
72 | | - TEST_UP_BLOCK.append(test_case) |
| 49 | +for params in dict_product( |
| 50 | + spatial_dims=range(2, 4), |
| 51 | + kernel_size=[1, 3], |
| 52 | + stride=[1, 2], |
| 53 | + norm_name=["batch", "instance"], |
| 54 | + in_size=[15, 16], |
| 55 | + trans_bias=[True, False], |
| 56 | +): |
| 57 | + out_size = params["in_size"] * params["stride"] |
| 58 | + test_case = [ |
| 59 | + { |
| 60 | + **{k: v for k, v in params.items() if k != "in_size"}, |
| 61 | + "in_channels": in_channels, |
| 62 | + "out_channels": out_channels, |
| 63 | + "upsample_kernel_size": params["stride"], |
| 64 | + }, |
| 65 | + (1, in_channels, *([params["in_size"]] * params["spatial_dims"])), |
| 66 | + (1, out_channels, *([out_size] * params["spatial_dims"])), |
| 67 | + (1, out_channels, *([params["in_size"] * params["stride"]] * params["spatial_dims"])), |
| 68 | + ] |
| 69 | + TEST_UP_BLOCK.append(test_case) |
73 | 70 |
|
74 | 71 |
|
75 | 72 | class TestResBasicBlock(unittest.TestCase): |
|
0 commit comments