-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtest_autovae.py
More file actions
160 lines (139 loc) · 6.99 KB
/
test_autovae.py
File metadata and controls
160 lines (139 loc) · 6.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=C0115,C0116,C0301
import unittest
import torch
from autovae import VAEGenerator
class TestVAEGenerator(unittest.TestCase):
    """Unit tests for the VAEGenerator class."""

    def setUp(self):
        # Shared generator (1024 px input, compression ratio 8) used by the
        # tests that exercise VAEGenerator's private helpers.
        self.input_resolution = 1024
        self.compression_ratio = 8
        self.generator = VAEGenerator(input_resolution=self.input_resolution, compression_ratio=self.compression_ratio)

    def test_initialization_valid(self):
        """Test that valid initialization parameters set the correct properties."""
        generator = VAEGenerator(input_resolution=1024, compression_ratio=8)
        self.assertEqual(generator.input_resolution, 1024)
        self.assertEqual(generator.compression_ratio, 8)
        generator = VAEGenerator(input_resolution=2048, compression_ratio=16)
        self.assertEqual(generator.input_resolution, 2048)
        self.assertEqual(generator.compression_ratio, 16)

    def test_initialization_invalid(self):
        """Test that an unsupported configuration (4096 px, ratio 16) raises NotImplementedError."""
        with self.assertRaises(NotImplementedError):
            VAEGenerator(input_resolution=4096, compression_ratio=16)

    # The assertions below require a CUDA tensor; skip instead of erroring on
    # CPU-only machines.
    @unittest.skipUnless(torch.cuda.is_available(), "CUDA is required for _generate_input")
    def test_generate_input(self):
        """Test that _generate_input produces a float16 CUDA tensor of shape (1, 3, res, res)."""
        input_tensor = self.generator._generate_input()
        expected_shape = (1, 3, self.input_resolution, self.input_resolution)
        self.assertEqual(input_tensor.shape, expected_shape)
        self.assertEqual(input_tensor.dtype, torch.float16)
        self.assertEqual(input_tensor.device.type, "cuda")

    def test_count_parameters(self):
        """Test that _count_parameters matches the number of trainable parameters of a model."""
        model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 5))
        expected_param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
        param_count = self.generator._count_parameters(model)
        self.assertEqual(param_count, expected_param_count)

    def test_load_base_json_skeleton(self):
        """Test that _load_base_json_skeleton returns a config dict with exactly the expected keys."""
        skeleton = self.generator._load_base_json_skeleton()
        expected_keys = {
            "_class_name",
            "_diffusers_version",
            "_name_or_path",
            "act_fn",
            "block_out_channels",
            "down_block_types",
            "force_upcast",
            "in_channels",
            "latent_channels",
            "layers_per_block",
            "norm_num_groups",
            "out_channels",
            "sample_size",
            "scaling_factor",
            "up_block_types",
        }
        self.assertEqual(set(skeleton.keys()), expected_keys)

    def test_generate_all_combinations(self):
        """Test that _generate_all_combinations yields the full cartesian product of the options."""
        attr = {"layers_per_block": [1, 2], "latent_channels": [4, 8]}
        combinations = self.generator._generate_all_combinations(attr)
        expected_combinations = [
            {"layers_per_block": 1, "latent_channels": 4},
            {"layers_per_block": 1, "latent_channels": 8},
            {"layers_per_block": 2, "latent_channels": 4},
            {"layers_per_block": 2, "latent_channels": 8},
        ]
        # Equal length plus membership of 4 distinct dicts implies an exact match (order-insensitive).
        self.assertEqual(len(combinations), len(expected_combinations))
        for combo in expected_combinations:
            self.assertIn(combo, combinations)

    def test_assign_attributes(self):
        """Test that _assign_attributes copies every chosen attribute into the skeleton."""
        choice = {
            "down_block_types": ["DownEncoderBlock2D"] * 4,
            "up_block_types": ["UpDecoderBlock2D"] * 4,
            "block_out_channels": [64, 128, 256, 512],
            "layers_per_block": 2,
            "latent_channels": 16,
        }
        skeleton = self.generator._assign_attributes(choice)
        self.assertEqual(skeleton["down_block_types"], choice["down_block_types"])
        self.assertEqual(skeleton["up_block_types"], choice["up_block_types"])
        self.assertEqual(skeleton["block_out_channels"], choice["block_out_channels"])
        self.assertEqual(skeleton["layers_per_block"], choice["layers_per_block"])
        self.assertEqual(skeleton["latent_channels"], choice["latent_channels"])

    def test_search_space_16x1024(self):
        """Test that _search_space_16x1024 returns list-valued options for the expected keys."""
        search_space = self.generator._search_space_16x1024()
        expected_keys = {
            "down_block_types",
            "up_block_types",
            "block_out_channels",
            "layers_per_block",
            "latent_channels",
        }
        self.assertEqual(set(search_space.keys()), expected_keys)
        self.assertTrue(all(isinstance(v, list) for v in search_space.values()))

    @staticmethod
    def _unsorted_records():
        """Return a fresh, deliberately unsorted list of candidate records.

        _sort_data_in_place mutates its argument, so every mode must start from
        a new list; otherwise a later mode could pass only because an earlier
        mode already left the data in the expected order.
        """
        return [
            {"param_diff": 10, "cuda_mem_diff": 100},
            {"param_diff": 5, "cuda_mem_diff": 50},
            {"param_diff": -3, "cuda_mem_diff": 30},
            {"param_diff": 7, "cuda_mem_diff": 70},
        ]

    def test_sort_data_in_place(self):
        """Test that _sort_data_in_place orders the records according to each sort mode."""
        # (mode, key to read back, expected order of that key after sorting)
        cases = (
            ("abs_param_diff", "param_diff", [-3, 5, 7, 10]),
            ("abs_cuda_mem_diff", "cuda_mem_diff", [30, 50, 70, 100]),
            # NOTE(review): assumes the MSE score grows with both |param_diff| and
            # |cuda_mem_diff| (they rise together in this fixture) — confirm
            # against _sort_data_in_place.
            ("mse", "param_diff", [-3, 5, 7, 10]),
        )
        for mode, key, expected in cases:
            with self.subTest(mode=mode):
                data = self._unsorted_records()
                self.generator._sort_data_in_place(data, mode=mode)
                self.assertEqual([item[key] for item in data], expected)

    def test_search_for_target_vae_invalid(self):
        """Test that search_for_target_vae raises ValueError when both budgets are zero."""
        with self.assertRaises(ValueError):
            self.generator.search_for_target_vae(parameters_budget=0, cuda_max_mem=0)
if __name__ == "__main__":
    # Only when the file is executed directly (not imported): collect and run
    # every TestCase defined in this module. module="__main__" is the default.
    unittest.main(module="__main__")