|
# NOTE(review): SIZE is not referenced anywhere in this chunk — presumably a
# default tensor/batch dimension used by tests elsewhere in the file; confirm
# at its use sites before changing.
SIZE = 100
# Dispatch tables keyed by spatial rank: 2 -> the 2-D torch.nn module,
# 3 -> the 3-D one. A rank outside {2, 3} raises KeyError at lookup time.
conv_module = {2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
|
class ConvBatchNorm_Fixed(nn.Module):
    """Conv -> BatchNorm with deterministically seeded weights.

    ``dim`` chooses the 2-D or 3-D variants; any other value raises
    KeyError. The conv carries no bias; the bn uses eps=0.001.
    """

    def __init__(self, dim, in_channels, out_channels, **kwargs):
        super(ConvBatchNorm_Fixed, self).__init__()
        # Fixed seed so the randomly initialized weights are reproducible
        # across test runs.
        torch.manual_seed(2018)
        conv_cls = {2: nn.Conv2d, 3: nn.Conv3d}[dim]
        bn_cls = {2: nn.BatchNorm2d, 3: nn.BatchNorm3d}[dim]
        self.conv = conv_cls(in_channels, out_channels, bias=False, **kwargs)
        self.bn = bn_cls(out_channels, eps=0.001)

    def forward(self, x):
        """Apply the convolution, then batch normalization."""
        conv_out = self.conv(x)
        return self.bn(conv_out)
|
class ConvRelu_Fixed(nn.Module):
    """Conv -> in-place ReLU with deterministically seeded weights.

    ``dim`` chooses Conv2d or Conv3d; any other value raises KeyError.
    The conv carries no bias.
    """

    def __init__(self, dim, in_channels, out_channels, **kwargs):
        super(ConvRelu_Fixed, self).__init__()
        # Fixed seed so weight initialization is reproducible across runs.
        torch.manual_seed(2018)
        conv_cls = {2: nn.Conv2d, 3: nn.Conv3d}[dim]
        self.conv = conv_cls(in_channels, out_channels, bias=False, **kwargs)

    def forward(self, x):
        """Convolve, then ReLU the result in place."""
        conv_out = self.conv(x)
        return F.relu(conv_out, inplace=True)
|
class ConvSum(nn.Module):
    """Element-wise sum of two parallel convolutions over the same input.

    ``dim`` chooses Conv2d or Conv3d; any other value raises KeyError.
    Both convs share the same channel configuration and carry no bias.
    """

    def __init__(self, dim, in_channels, out_channels, **kwargs):
        super(ConvSum, self).__init__()
        # Fixed seed so both branches get reproducible initial weights.
        torch.manual_seed(2018)
        conv_cls = {2: nn.Conv2d, 3: nn.Conv3d}[dim]
        self.conv = conv_cls(in_channels, out_channels, bias=False, **kwargs)
        self.conv1 = conv_cls(in_channels, out_channels, bias=False, **kwargs)

    def forward(self, x):
        """Return conv(x) + conv1(x)."""
        return self.conv(x) + self.conv1(x)
|
152 | | -class CascadedConv2dBnSumRelu(nn.Module): |
153 | | - def __init__(self, in_channels, mid_channels, out_channels, **kwargs): |
154 | | - super(CascadedConv2dBnSumRelu, self).__init__() |
155 | | - torch.manual_seed(2018) |
156 | | - self.conv = nn.Conv2d(in_channels, mid_channels, bias=False, **kwargs) |
157 | | - self.conv1 = nn.Conv2d( |
158 | | - mid_channels, out_channels, bias=False, padding=1, **kwargs) |
159 | | - self.conv2 = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) |
160 | | - self.bn = nn.BatchNorm2d(mid_channels, eps=0.001) |
161 | | - self.bn1 = nn.BatchNorm2d(out_channels, eps=0.001) |
162 | | - self.bn2 = nn.BatchNorm2d(out_channels, eps=0.001) |
163 | | - |
164 | | - def forward(self, x): |
165 | | - a = self.conv(x) |
166 | | - a = self.bn(a) |
167 | | - a = F.relu(a, inplace=True) |
168 | | - a = self.conv1(a) |
169 | | - a = self.bn1(a) |
170 | | - b = self.conv2(x) |
171 | | - b = self.bn2(b) |
172 | | - return F.relu(a.add_(b), inplace=True) |
173 | | - |
174 | | -class CascadedConv3dBnSumRelu(nn.Module): |
175 | | - def __init__(self, in_channels, mid_channels, out_channels, **kwargs): |
176 | | - super(CascadedConv3dBnSumRelu, self).__init__() |
| 121 | +class CascadedConvBnSumRelu(nn.Module): |
| 122 | + def __init__(self, dim, in_channels, mid_channels, out_channels, **kwargs): |
| 123 | + super(CascadedConvBnSumRelu, self).__init__() |
177 | 124 | torch.manual_seed(2018) |
178 | | - self.conv = nn.Conv3d(in_channels, mid_channels, bias=False, **kwargs) |
179 | | - self.conv1 = nn.Conv3d( |
| 125 | + self.conv = conv_module[dim](in_channels, mid_channels, bias=False, **kwargs) |
| 126 | + self.conv1 = conv_module[dim]( |
180 | 127 | mid_channels, out_channels, bias=False, padding=1, **kwargs) |
181 | | - self.conv2 = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs) |
182 | | - self.bn = nn.BatchNorm3d(mid_channels, eps=0.001) |
183 | | - self.bn1 = nn.BatchNorm3d(out_channels, eps=0.001) |
184 | | - self.bn2 = nn.BatchNorm3d(out_channels, eps=0.001) |
| 128 | + self.conv2 = conv_module[dim](in_channels, out_channels, bias=False, **kwargs) |
| 129 | + self.bn = bn_module[dim](mid_channels, eps=0.001) |
| 130 | + self.bn1 = bn_module[dim](out_channels, eps=0.001) |
| 131 | + self.bn2 = bn_module[dim](out_channels, eps=0.001) |
185 | 132 |
|
186 | 133 | def forward(self, x): |
187 | 134 | a = self.conv(x) |
@@ -280,93 +227,93 @@ def _test_output_bf16(self, model, x, kind=None, prec=None): |
280 | 227 |
|
    def test_output_conv_bn_2d(self):
        """2-D conv+bn: expect an ``aten::conv2d`` node in the checked graph.

        NOTE(review): ``kind`` presumably names the op node the helper asserts
        is present after optimization (bn folded into conv) — confirm against
        ``_test_output``. The bf16 run relaxes tolerance to 0.02.
        """
        self._test_output(
            ConvBatchNorm_Fixed(2, 3, 32, kernel_size=3, stride=1),
            torch.randn(32, 3, 224, 224),
            kind="aten::conv2d")
        self._test_output_bf16(
            ConvBatchNorm_Fixed(2, 3, 32, kernel_size=3, stride=1),
            torch.randn(32, 3, 224, 224),
            kind="aten::conv2d",
            prec=0.02)
291 | 238 |
|
292 | 239 |
|
    def test_output_conv_bn_3d(self):
        """3-D conv+bn: expect an ``aten::conv3d`` node in the checked graph.

        NOTE(review): ``kind`` presumably names the op node the helper asserts
        is present after optimization — confirm against ``_test_output``.
        The bf16 run relaxes tolerance to 0.02.
        """
        self._test_output(
            ConvBatchNorm_Fixed(3, 3, 32, kernel_size=3, stride=1),
            torch.randn(32, 3, 112, 112, 112),
            kind="aten::conv3d")
        self._test_output_bf16(
            ConvBatchNorm_Fixed(3, 3, 32, kernel_size=3, stride=1),
            torch.randn(32, 3, 112, 112, 112),
            kind="aten::conv3d",
            prec=0.02)
303 | 250 |
|
304 | 251 |
|
    def test_output_conv_relu_2d(self):
        """2-D conv+relu fusion: expect an ``ipex::conv2d_relu`` node.

        NOTE(review): ``kind`` presumably names the fused op node the helper
        asserts is present in the optimized graph — confirm against
        ``_test_output``.
        """
        self._test_output(
            ConvRelu_Fixed(2, 3, 32, kernel_size=3, stride=1),
            torch.randn(32, 3, 224, 224),
            kind="ipex::conv2d_relu")
        self._test_output_bf16(
            ConvRelu_Fixed(2, 3, 32, kernel_size=3, stride=1),
            torch.randn(32, 3, 224, 224),
            kind="ipex::conv2d_relu")
314 | 261 |
|
315 | 262 |
|
    def test_output_conv_relu_3d(self):
        """3-D conv+relu fusion: expect an ``ipex::conv3d_relu`` node.

        NOTE(review): ``kind`` presumably names the fused op node the helper
        asserts is present in the optimized graph — confirm against
        ``_test_output``.
        """
        self._test_output(
            ConvRelu_Fixed(3, 3, 32, kernel_size=3, stride=1),
            torch.randn(32, 3, 112, 112, 112),
            kind="ipex::conv3d_relu")
        self._test_output_bf16(
            ConvRelu_Fixed(3, 3, 32, kernel_size=3, stride=1),
            torch.randn(32, 3, 112, 112, 112),
            kind="ipex::conv3d_relu")
325 | 272 |
|
326 | 273 |
|
    def test_output_conv_sum_2d(self):
        """2-D conv+sum fusion: expect an ``ipex::conv2d_sum`` node.

        NOTE(review): ``kind`` presumably names the fused op node the helper
        asserts is present in the optimized graph — confirm against
        ``_test_output``. The bf16 run relaxes tolerance to 0.04.
        """
        self._test_output(
            ConvSum(2, 3, 32, kernel_size=3, stride=1),
            torch.randn(32, 3, 224, 224),
            kind="ipex::conv2d_sum")
        self._test_output_bf16(
            ConvSum(2, 3, 32, kernel_size=3, stride=1),
            torch.randn(32, 3, 224, 224),
            kind="ipex::conv2d_sum",
            prec=0.04)
337 | 284 |
|
338 | 285 |
|
    def test_output_conv_sum_3d(self):
        """3-D conv+sum fusion: expect an ``ipex::conv3d_sum`` node.

        NOTE(review): ``kind`` presumably names the fused op node the helper
        asserts is present in the optimized graph — confirm against
        ``_test_output``. The bf16 run relaxes tolerance to 0.04.
        """
        self._test_output(
            ConvSum(3, 3, 32, kernel_size=3, stride=1),
            torch.randn(32, 3, 112, 112, 112),
            kind="ipex::conv3d_sum")
        self._test_output_bf16(
            ConvSum(3, 3, 32, kernel_size=3, stride=1),
            torch.randn(32, 3, 112, 112, 112),
            kind="ipex::conv3d_sum",
            prec=0.04)
349 | 296 |
|
350 | 297 |
|
    def test_output_cascaded_conv_bn_sum_relu_2d(self):
        """2-D cascaded conv/bn/sum/relu: expect ``ipex::conv2d_sum_relu``.

        NOTE(review): ``kind`` presumably names the fused op node the helper
        asserts is present in the optimized graph — confirm against
        ``_test_output``. The bf16 run relaxes tolerance to 0.02.
        """
        self._test_output(
            CascadedConvBnSumRelu(2, 3, 64, 32, kernel_size=3, stride=1),
            torch.rand(32, 3, 224, 224),
            kind="ipex::conv2d_sum_relu")
        self._test_output_bf16(
            CascadedConvBnSumRelu(2, 3, 64, 32, kernel_size=3, stride=1),
            torch.rand(32, 3, 224, 224),
            kind="ipex::conv2d_sum_relu",
            prec=0.02)
361 | 308 |
|
362 | 309 |
|
    def test_output_cascaded_conv_bn_sum_relu_3d(self):
        """3-D cascaded conv/bn/sum/relu: expect ``ipex::conv3d_sum_relu``.

        NOTE(review): ``kind`` presumably names the fused op node the helper
        asserts is present in the optimized graph — confirm against
        ``_test_output``. The bf16 run relaxes tolerance to 0.02.
        """
        self._test_output(
            CascadedConvBnSumRelu(3, 3, 64, 32, kernel_size=3, stride=1),
            torch.rand(32, 3, 112, 112, 112),
            kind="ipex::conv3d_sum_relu")
        self._test_output_bf16(
            CascadedConvBnSumRelu(3, 3, 64, 32, kernel_size=3, stride=1),
            torch.rand(32, 3, 112, 112, 112),
            kind="ipex::conv3d_sum_relu",
            prec=0.02)
|
0 commit comments