10 | 10 | from functools import reduce |
11 | 11 | import copy |
12 | 12 | import sys |
| 13 | +import itertools |
13 | 14 | import torch |
14 | 15 | import intel_pytorch_extension as ipex |
15 | 16 |
@@ -187,76 +188,134 @@ def test_Conv2d_backward(self): |
187 | 188 | self.assertEqual(in_man_bf16.grad.float(), in_auto_mix.grad.float()) |
188 | 189 |
189 | 190 | class TestDeconv(TestCase): |
190 | | - def test_Deconv2d_with_cpu(self): |
| 191 | + def _deconv_params_list(self): |
| 192 | + params_dict = { |
| 193 | + "input_height": [8], |
| 194 | + "input_width": [8], |
| 195 | + "input_depth": [8], |
| 196 | + "input_channel_per_group": [10], |
| 197 | + "output_channel_per_group": [10], |
| 198 | + "kernel_size": [3, 4], |
| 199 | + "bias": [False, True], |
| 200 | + "stride": [2], # [1, 2] |
| 201 | + "padding": [1, 2], |
| 202 | + "output_padding": [2], |
| 203 | + "groups": [1, 2], |
| 204 | + "dilation": [1, 3, 4], |
| 205 | + } |
| 206 | + |
| 207 | + params_list = [] |
| 208 | + |
| 209 | + for key, value in params_dict.items(): |
| 210 | + params_list.append(value) |
| 211 | + return params_list |
| 212 | + |
| 213 | + def _test_deconv(self, dims): |
191 | 214 | rand_seed = int(get_rand_seed()) |
192 | 215 |         print("{} rand seed: {}".format(sys._getframe().f_code.co_name, rand_seed)) |
193 | 216 | torch.manual_seed(rand_seed) |
194 | 217 |
195 | | - _deconv = torch.nn.ConvTranspose2d(2, 3, (3, 3)) |
| 218 | + params_list = self._deconv_params_list() |
196 | 219 |
197 | | - deconv_man_bf16 = copy.deepcopy(_deconv).to(device=device).to(torch.bfloat16) |
198 | | - deconv_auto_mix = copy.deepcopy(_deconv).to(device=device) |
199 | | - deconv_auto_mix_train =copy.deepcopy(_deconv).to(device=device) |
| 220 | + for input_width, input_height, input_depth, input_channel_per_group, output_channel_per_group, kernel_size, bias, stride, padding, output_padding, groups, dilation in itertools.product(*params_list): |
| 221 | +            if (output_padding < stride or output_padding < dilation) \
| 222 | +                and ((input_height - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1 > 0) \
| 223 | +                and ((input_width - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1 > 0) \
| 224 | +                and ((input_depth - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1 > 0):  # skip configs PyTorch rejects or whose output size would be non-positive (relation sketched after this hunk)
200 | 225 |
201 | | - _in_cpu = torch.rand((1, 2, 7, 7)) |
202 | | - in_auto_mix = _in_cpu.to(device=device) |
203 | | - in_man_bf16 = in_auto_mix.to(torch.bfloat16) |
204 | | - |
205 | | - in_auto_mix_train = _in_cpu.to(device=device) |
| 226 | + ic = input_channel_per_group * groups |
| 227 | + oc = output_channel_per_group * groups |
206 | 228 |
207 | | - res_cpu_fp32 = _deconv(_in_cpu) |
| 229 | + if dims == 2: |
| 230 | + module = torch.nn.ConvTranspose2d(ic, oc, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, groups=groups, bias=bias, dilation=dilation) |
| 231 | + x = torch.rand((2, ic, input_height, input_width)) |
| 232 | + elif dims == 3: |
| 233 | + module = torch.nn.ConvTranspose3d(ic, oc, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, groups=groups, bias=bias, dilation=dilation) |
| 234 | + x = torch.rand((2, ic, input_depth, input_height, input_width)) |
208 | 235 |
209 | | - with AutoDNNL(True), AutoMixPrecision(False): |
210 | | - res_man_bf16 = deconv_man_bf16(in_man_bf16) |
211 | | - self.assertEqual(res_man_bf16.dtype, torch.bfloat16) |
212 | | - self.assertEqual(res_cpu_fp32.bfloat16().float(), res_man_bf16, 1e-2) |
213 | | - |
214 | | - with AutoMixPrecision(True): |
215 | | - self.assertEqual(in_auto_mix.dtype, torch.float) |
216 | | - self.assertFalse(ipex.core.is_bf16_dil_tensor(in_auto_mix)) |
217 | | - res_auto_bf16 = deconv_auto_mix(in_auto_mix) |
218 | | - self.assertTrue(ipex.core.is_bf16_dil_tensor(res_auto_bf16)) |
219 | | - self.assertEqual(res_man_bf16.float(), res_auto_bf16.float(), 1e-2) |
220 | | - |
221 | | - with AutoMixPrecision(True, train=True): |
222 | | - self.assertEqual(in_auto_mix_train.dtype, torch.float) |
223 | | - self.assertFalse(ipex.core.is_bf16_dil_tensor(in_auto_mix_train)) |
224 | | - res_auto_bf16_train = deconv_auto_mix_train(in_auto_mix_train) |
225 | | - self.assertTrue(ipex.core.is_bf16_dil_tensor(res_auto_bf16_train)) |
226 | | - self.assertEqual(res_man_bf16.float(), res_auto_bf16_train.float(), 1e-2) |
227 | | - |
228 | | - def test_Deconv2d_backward(self): |
229 | | - rand_seed = int(get_rand_seed()) |
230 | | - print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed)) |
231 | | - torch.manual_seed(rand_seed) |
232 | | - |
233 | | - input = torch.rand(2, 10, 8, 8) |
234 | | - for bias in [True, False]: |
235 | | - _deconv = torch.nn.ConvTranspose2d(10, 10, |
236 | | - kernel_size=4, stride=2, bias=bias) |
237 | | - deconv_man_bf16 =copy.deepcopy(_deconv).to(device=device).to(torch.bfloat16) |
238 | | - deconv_auto_mix =copy.deepcopy(_deconv).to(device=device) |
239 | | - _in_cpu = input.clone().requires_grad_() |
240 | | - in_auto_mix = input.clone().to(device=device).requires_grad_() |
241 | | - in_man_bf16 = input.clone().to(device=device).to(torch.bfloat16).requires_grad_() |
242 | | - out_cpu = _deconv(_in_cpu).sum() |
243 | | - out_cpu.backward() |
244 | | - with AutoDNNL(True), AutoMixPrecision(False, train=True): |
245 | | - out_man_bf16 = deconv_man_bf16(in_man_bf16).sum() |
246 | | - out_man_bf16.backward() |
247 | | - self.assertEqual(in_man_bf16.grad.dtype, torch.bfloat16) |
248 | | - self.assertEqual(_in_cpu.grad.bfloat16().float(), in_man_bf16.grad, 1e-2) |
249 | | - |
250 | | - with AutoMixPrecision(True, train=True): |
251 | | - self.assertEqual(in_auto_mix.dtype, torch.float) |
252 | | - self.assertFalse(ipex.core.is_bf16_dil_tensor(in_auto_mix)) |
253 | | - out_auto_bf16 = deconv_auto_mix(in_auto_mix).sum() |
254 | | - out_auto_bf16.backward() |
255 | | - self.assertTrue(ipex.core.is_bf16_dil_tensor(in_auto_mix.grad)) |
256 | | - self.assertFalse(ipex.core.is_bf16_dil_tensor(deconv_auto_mix.weight)) |
257 | | - self.assertFalse(ipex.core.is_bf16_dil_tensor(deconv_auto_mix.weight.grad)) |
258 | | - self.assertEqual(in_man_bf16.grad.float(), in_auto_mix.grad.float()) |
| 236 | + module_auto_mix_infer = copy.deepcopy(module).to(device=device) |
| 237 | + module_auto_mix_train = copy.deepcopy(module).to(device=device) |
259 | 238 |
| 239 | + x_aten = x.clone().requires_grad_() |
| 240 | + x_auto_mix_infer = x.clone().to(device=device).requires_grad_() |
| 241 | + x_auto_mix_train = x.clone().to(device=device).requires_grad_() |
| 242 | + |
| 243 | + y_aten = module(x_aten) |
| 244 | + y_aten.sum().backward() |
| 245 | + |
| 246 | + with AutoDNNL(True), AutoMixPrecision(True, train=False): |
| 247 | + self.assertEqual(x_auto_mix_infer.dtype, torch.float) |
| 248 | + self.assertFalse(ipex.core.is_bf16_dil_tensor(x_auto_mix_infer)) |
| 249 | + self.assertFalse(ipex.core.is_bf16_dil_tensor(module_auto_mix_infer.weight)) |
| 250 | + if bias: |
| 251 | + self.assertFalse(ipex.core.is_bf16_dil_tensor(module_auto_mix_infer.bias)) |
| 252 | + |
| 253 | + y_auto_mix_infer = module_auto_mix_infer(x_auto_mix_infer) |
| 254 | + y_auto_mix_infer.sum().backward() |
| 255 | + |
| 256 | + if padding - output_padding + stride > 0: |
| 257 | + self.assertTrue(ipex.core.is_bf16_dil_tensor(x_auto_mix_infer.grad)) |
| 258 | + self.assertTrue(ipex.core.is_bf16_dil_tensor(module_auto_mix_infer.weight)) |
| 259 | + if bias: |
| 260 | + self.assertTrue(ipex.core.is_bf16_dil_tensor(module_auto_mix_infer.bias)) |
| 261 | + |
| 262 | + self.assertEqual( |
| 263 | + y_aten, y_auto_mix_infer, 1e-2) |
| 264 | + |
| 265 | +                    # mkldnn does not support the case where
| 266 | +                    # padding - output_padding + stride <= 0,
| 267 | +                    # while PyTorch does; execution falls back to the ATen path here
| 268 | +                else:
| 269 | +                    # use a looser threshold here, presumably because the input has already been reordered to bf16
| 270 | + self.assertEqual( |
| 271 | + y_aten, y_auto_mix_infer, 4e-3) |
| 272 | + |
| 273 | + with AutoDNNL(True), AutoMixPrecision(True, train=True): |
| 274 | + self.assertEqual(x_auto_mix_train.dtype, torch.float) |
| 275 | + self.assertFalse(ipex.core.is_bf16_dil_tensor(x_auto_mix_train)) |
| 276 | + self.assertFalse(ipex.core.is_bf16_dil_tensor(module_auto_mix_train.weight)) |
| 277 | + if bias: |
| 278 | + self.assertFalse(ipex.core.is_bf16_dil_tensor(module_auto_mix_train.bias)) |
| 279 | + |
| 280 | + y_auto_mix_train = module_auto_mix_train(x_auto_mix_train) |
| 281 | + y_auto_mix_train.sum().backward() |
| 282 | + |
| 283 | + if padding - output_padding + stride > 0: |
| 284 | + self.assertTrue(ipex.core.is_bf16_dil_tensor(x_auto_mix_train.grad)) |
| 285 | + self.assertFalse(ipex.core.is_bf16_dil_tensor(module_auto_mix_train.weight)) |
| 286 | + self.assertFalse(ipex.core.is_bf16_dil_tensor(module_auto_mix_train.weight.grad)) |
| 287 | + if bias: |
| 288 | + self.assertFalse(ipex.core.is_bf16_dil_tensor(module_auto_mix_train.bias)) |
| 289 | + self.assertFalse(ipex.core.is_bf16_dil_tensor(module_auto_mix_train.bias.grad)) |
| 290 | + |
| 291 | + self.assertEqual( |
| 292 | + y_aten, y_auto_mix_train, 1e-2) |
| 293 | + self.assertEqual( |
| 294 | + module.weight.grad, module_auto_mix_train.weight.grad, 5e-1) |
| 295 | + self.assertEqual( |
| 296 | + x_aten.grad, x_auto_mix_train.grad, 1e-2) |
| 297 | + if bias: |
| 298 | + self.assertEqual(module.bias.grad, module_auto_mix_train.bias.grad) |
| 299 | + |
| 300 | +                    # mkldnn does not support the case where
| 301 | +                    # padding - output_padding + stride <= 0,
| 302 | +                    # while PyTorch does; execution falls back to the ATen path here
| 303 | +                else:
| 304 | +                    # use a looser threshold here, presumably because the input has already been reordered to bf16
| 305 | + self.assertEqual( |
| 306 | + y_aten, y_auto_mix_train, 3e-3) |
| 307 | + self.assertEqual( |
| 308 | + module.weight.grad, module_auto_mix_train.weight.grad, 2e-1) |
| 309 | + self.assertEqual( |
| 310 | + x_aten.grad, x_auto_mix_train.grad) |
| 311 | + if bias: |
| 312 | + self.assertEqual(module.bias.grad, module_auto_mix_train.bias.grad) |
| 313 | + |
| 314 | + def test_deconv2d(self): |
| 315 | + self._test_deconv(dims=2) |
| 316 | + |
| 317 | + def test_deconv3d(self): |
| 318 | + self._test_deconv(dims=3) |
260 | 319 |
261 | 320 | class TestBatchNorm(TestCase): |
262 | 321 | def test_batch_norm2d(self): |
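The parameter sweep in `_test_deconv` above only runs configurations that PyTorch accepts and that yield a positive output size. A minimal standalone sketch of the output-size relation that filter encodes, written in the same `itertools.product` style; the `deconv_output_size` helper and the small parameter dict are illustrative only, not part of the test file:

```python
import itertools
import torch

def deconv_output_size(in_size, kernel_size, stride, padding, output_padding, dilation):
    # Transposed-convolution output length along one spatial dimension --
    # the same expression used in the validity filter above.
    return (in_size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

# Tiny sweep in the style of _deconv_params_list / itertools.product.
params = {"kernel_size": [3, 4], "stride": [2], "padding": [1, 2],
          "output_padding": [1], "dilation": [1, 3]}
for kernel_size, stride, padding, output_padding, dilation in itertools.product(*params.values()):
    if not (output_padding < stride or output_padding < dilation):
        continue  # PyTorch requires output_padding to be smaller than stride or dilation
    expected = deconv_output_size(8, kernel_size, stride, padding, output_padding, dilation)
    if expected <= 0:
        continue  # mirror the "output size must be positive" part of the filter
    m = torch.nn.ConvTranspose2d(4, 4, kernel_size=kernel_size, stride=stride, padding=padding,
                                 output_padding=output_padding, dilation=dilation)
    y = m(torch.rand(1, 4, 8, 8))
    assert y.shape[-1] == expected
```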
@@ -1033,6 +1092,7 @@ def test_linear(self): |
1033 | 1092 |
1034 | 1093 | def test_linear_backward(self): |
1035 | 1094 | rand_seed = int(get_rand_seed()) |
| 1095 | +        # rand_seed = 1600407821102260224  # known flaky seed: fails self.assertEqual(_in_cpu.grad.bfloat16().float(), in_man_bf16.grad, 2e-2) with "AssertionError: tensor(0.0312) not less than or equal to 0.02"
1036 | 1096 |         print("{} rand seed: {}".format(sys._getframe().f_code.co_name, rand_seed)) |
1037 | 1097 | torch.manual_seed(rand_seed) |
1038 | 1098 | in_features = torch.randint(3, 10, (1,)).item() |
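The commented-out seed above records a run where a bf16 gradient differed from its fp32 counterpart by 0.0312, just over the 2e-2 tolerance. A deviation of that size is consistent with bfloat16's 8-bit significand: around 8.0 the representable values are only 2**-4 apart, so a single rounding step can move a number by up to 0.03125. A quick illustration in plain PyTorch, independent of the test code:

```python
import torch

# bfloat16 keeps an 8-bit significand, so values in [8, 16) are spaced
# 2**-4 = 0.0625 apart and rounding alone can shift a number by up to 0.03125.
x = torch.tensor([8.0, 8.02, 8.05])
print(x.bfloat16().float())                    # tensor([8.0000, 8.0000, 8.0625])
print((x.bfloat16().float() - x).abs().max())  # tensor(0.0200) -- within the 0.03125 half-ULP bound
```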