Skip to content

Commit 05e2754

Browse files
simple jit test cases (#83)
1 parent 04aa449 commit 05e2754

File tree

1 file changed

+41
-94
lines changed

1 file changed

+41
-94
lines changed

tests/cpu/test_jit.py

Lines changed: 41 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -81,107 +81,54 @@
8181
SIZE = 100
8282

8383

84-
class Conv2dBatchNorm2d_Fixed(nn.Module):
85-
def __init__(self, in_channels, out_channels, **kwargs):
86-
super(Conv2dBatchNorm2d_Fixed, self).__init__()
87-
seed = 2018
88-
torch.manual_seed(seed)
89-
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
90-
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
91-
92-
def forward(self, x):
93-
return self.bn(self.conv(x))
84+
conv_module = {2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
85+
bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
9486

95-
class Conv3dBatchNorm3d_Fixed(nn.Module):
96-
def __init__(self, in_channels, out_channels, **kwargs):
97-
super(Conv3dBatchNorm3d_Fixed, self).__init__()
87+
class ConvBatchNorm_Fixed(nn.Module):
88+
def __init__(self, dim, in_channels, out_channels, **kwargs):
89+
super(ConvBatchNorm_Fixed, self).__init__()
9890
seed = 2018
9991
torch.manual_seed(seed)
100-
self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
101-
self.bn = nn.BatchNorm3d(out_channels, eps=0.001)
92+
self.conv = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
93+
self.bn = bn_module[dim](out_channels, eps=0.001)
10294

10395
def forward(self, x):
10496
return self.bn(self.conv(x))
10597

106-
class Conv2dRelu_Fixed(nn.Module):
107-
def __init__(self, in_channels, out_channels, **kwargs):
108-
super(Conv2dRelu_Fixed, self).__init__()
98+
class ConvRelu_Fixed(nn.Module):
99+
def __init__(self, dim, in_channels, out_channels, **kwargs):
100+
super(ConvRelu_Fixed, self).__init__()
109101
seed = 2018
110102
torch.manual_seed(seed)
111-
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
103+
self.conv = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
112104

113105
def forward(self, x):
114106
return F.relu(self.conv(x), inplace=True)
115107

116-
class Conv3dRelu_Fixed(nn.Module):
117-
def __init__(self, in_channels, out_channels, **kwargs):
118-
super(Conv3dRelu_Fixed, self).__init__()
119-
seed = 2018
120-
torch.manual_seed(seed)
121-
self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
122-
123-
def forward(self, x):
124-
return F.relu(self.conv(x), inplace=True)
125-
126-
class Conv2dSum(nn.Module):
127-
def __init__(self, in_channels, out_channels, **kwargs):
128-
super(Conv2dSum, self).__init__()
129-
seed = 2018
130-
torch.manual_seed(seed)
131-
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
132-
self.conv1 = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
133-
134-
def forward(self, x):
135-
a = self.conv(x)
136-
b = self.conv1(x)
137-
return a+b
138-
139-
class Conv3dSum(nn.Module):
140-
def __init__(self, in_channels, out_channels, **kwargs):
141-
super(Conv3dSum, self).__init__()
108+
class ConvSum(nn.Module):
109+
def __init__(self, dim, in_channels, out_channels, **kwargs):
110+
super(ConvSum, self).__init__()
142111
seed = 2018
143112
torch.manual_seed(seed)
144-
self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
145-
self.conv1 = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
113+
self.conv = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
114+
self.conv1 = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
146115

147116
def forward(self, x):
148117
a = self.conv(x)
149118
b = self.conv1(x)
150119
return a+b
151120

152-
class CascadedConv2dBnSumRelu(nn.Module):
153-
def __init__(self, in_channels, mid_channels, out_channels, **kwargs):
154-
super(CascadedConv2dBnSumRelu, self).__init__()
155-
torch.manual_seed(2018)
156-
self.conv = nn.Conv2d(in_channels, mid_channels, bias=False, **kwargs)
157-
self.conv1 = nn.Conv2d(
158-
mid_channels, out_channels, bias=False, padding=1, **kwargs)
159-
self.conv2 = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
160-
self.bn = nn.BatchNorm2d(mid_channels, eps=0.001)
161-
self.bn1 = nn.BatchNorm2d(out_channels, eps=0.001)
162-
self.bn2 = nn.BatchNorm2d(out_channels, eps=0.001)
163-
164-
def forward(self, x):
165-
a = self.conv(x)
166-
a = self.bn(a)
167-
a = F.relu(a, inplace=True)
168-
a = self.conv1(a)
169-
a = self.bn1(a)
170-
b = self.conv2(x)
171-
b = self.bn2(b)
172-
return F.relu(a.add_(b), inplace=True)
173-
174-
class CascadedConv3dBnSumRelu(nn.Module):
175-
def __init__(self, in_channels, mid_channels, out_channels, **kwargs):
176-
super(CascadedConv3dBnSumRelu, self).__init__()
121+
class CascadedConvBnSumRelu(nn.Module):
122+
def __init__(self, dim, in_channels, mid_channels, out_channels, **kwargs):
123+
super(CascadedConvBnSumRelu, self).__init__()
177124
torch.manual_seed(2018)
178-
self.conv = nn.Conv3d(in_channels, mid_channels, bias=False, **kwargs)
179-
self.conv1 = nn.Conv3d(
125+
self.conv = conv_module[dim](in_channels, mid_channels, bias=False, **kwargs)
126+
self.conv1 = conv_module[dim](
180127
mid_channels, out_channels, bias=False, padding=1, **kwargs)
181-
self.conv2 = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs)
182-
self.bn = nn.BatchNorm3d(mid_channels, eps=0.001)
183-
self.bn1 = nn.BatchNorm3d(out_channels, eps=0.001)
184-
self.bn2 = nn.BatchNorm3d(out_channels, eps=0.001)
128+
self.conv2 = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
129+
self.bn = bn_module[dim](mid_channels, eps=0.001)
130+
self.bn1 = bn_module[dim](out_channels, eps=0.001)
131+
self.bn2 = bn_module[dim](out_channels, eps=0.001)
185132

186133
def forward(self, x):
187134
a = self.conv(x)
@@ -280,93 +227,93 @@ def _test_output_bf16(self, model, x, kind=None, prec=None):
280227

281228
def test_output_conv_bn_2d(self):
282229
self._test_output(
283-
Conv2dBatchNorm2d_Fixed(3, 32, kernel_size=3, stride=1),
230+
ConvBatchNorm_Fixed(2, 3, 32, kernel_size=3, stride=1),
284231
torch.randn(32, 3, 224, 224),
285232
kind="aten::conv2d")
286233
self._test_output_bf16(
287-
Conv2dBatchNorm2d_Fixed(3, 32, kernel_size=3, stride=1),
234+
ConvBatchNorm_Fixed(2, 3, 32, kernel_size=3, stride=1),
288235
torch.randn(32, 3, 224, 224),
289236
kind="aten::conv2d",
290237
prec=0.02)
291238

292239

293240
def test_output_conv_bn_3d(self):
294241
self._test_output(
295-
Conv3dBatchNorm3d_Fixed(3, 32, kernel_size=3, stride=1),
242+
ConvBatchNorm_Fixed(3, 3, 32, kernel_size=3, stride=1),
296243
torch.randn(32, 3, 112, 112, 112),
297244
kind="aten::conv3d")
298245
self._test_output_bf16(
299-
Conv3dBatchNorm3d_Fixed(3, 32, kernel_size=3, stride=1),
246+
ConvBatchNorm_Fixed(3, 3, 32, kernel_size=3, stride=1),
300247
torch.randn(32, 3, 112, 112, 112),
301248
kind="aten::conv3d",
302249
prec=0.02)
303250

304251

305252
def test_output_conv_relu_2d(self):
306253
self._test_output(
307-
Conv2dRelu_Fixed(3, 32, kernel_size=3, stride=1),
254+
ConvRelu_Fixed(2, 3, 32, kernel_size=3, stride=1),
308255
torch.randn(32, 3, 224, 224),
309256
kind="ipex::conv2d_relu")
310257
self._test_output_bf16(
311-
Conv2dRelu_Fixed(3, 32, kernel_size=3, stride=1),
258+
ConvRelu_Fixed(2, 3, 32, kernel_size=3, stride=1),
312259
torch.randn(32, 3, 224, 224),
313260
kind="ipex::conv2d_relu")
314261

315262

316263
def test_output_conv_relu_3d(self):
317264
self._test_output(
318-
Conv3dRelu_Fixed(3, 32, kernel_size=3, stride=1),
265+
ConvRelu_Fixed(3, 3, 32, kernel_size=3, stride=1),
319266
torch.randn(32, 3, 112, 112, 112),
320267
kind="ipex::conv3d_relu")
321268
self._test_output_bf16(
322-
Conv3dRelu_Fixed(3, 32, kernel_size=3, stride=1),
269+
ConvRelu_Fixed(3, 3, 32, kernel_size=3, stride=1),
323270
torch.randn(32, 3, 112, 112, 112),
324271
kind="ipex::conv3d_relu")
325272

326273

327274
def test_output_conv_sum_2d(self):
328275
self._test_output(
329-
Conv2dSum(3, 32, kernel_size=3, stride=1),
276+
ConvSum(2, 3, 32, kernel_size=3, stride=1),
330277
torch.randn(32, 3, 224, 224),
331278
kind="ipex::conv2d_sum")
332279
self._test_output_bf16(
333-
Conv2dSum(3, 32, kernel_size=3, stride=1),
280+
ConvSum(2, 3, 32, kernel_size=3, stride=1),
334281
torch.randn(32, 3, 224, 224),
335282
kind="ipex::conv2d_sum",
336283
prec=0.04)
337284

338285

339286
def test_output_conv_sum_3d(self):
340287
self._test_output(
341-
Conv3dSum(3, 32, kernel_size=3, stride=1),
288+
ConvSum(3, 3, 32, kernel_size=3, stride=1),
342289
torch.randn(32, 3, 112, 112, 112),
343290
kind="ipex::conv3d_sum")
344291
self._test_output_bf16(
345-
Conv3dSum(3, 32, kernel_size=3, stride=1),
292+
ConvSum(3, 3, 32, kernel_size=3, stride=1),
346293
torch.randn(32, 3, 112, 112, 112),
347294
kind="ipex::conv3d_sum",
348295
prec=0.04)
349296

350297

351298
def test_output_cascaded_conv_bn_sum_relu_2d(self):
352299
self._test_output(
353-
CascadedConv2dBnSumRelu(3, 64, 32, kernel_size=3, stride=1),
300+
CascadedConvBnSumRelu(2, 3, 64, 32, kernel_size=3, stride=1),
354301
torch.rand(32, 3, 224, 224),
355302
kind="ipex::conv2d_sum_relu")
356303
self._test_output_bf16(
357-
CascadedConv2dBnSumRelu(3, 64, 32, kernel_size=3, stride=1),
304+
CascadedConvBnSumRelu(2, 3, 64, 32, kernel_size=3, stride=1),
358305
torch.rand(32, 3, 224, 224),
359306
kind="ipex::conv2d_sum_relu",
360307
prec=0.02)
361308

362309

363310
def test_output_cascaded_conv_bn_sum_relu_3d(self):
364311
self._test_output(
365-
CascadedConv3dBnSumRelu(3, 64, 32, kernel_size=3, stride=1),
312+
CascadedConvBnSumRelu(3, 3, 64, 32, kernel_size=3, stride=1),
366313
torch.rand(32, 3, 112, 112, 112),
367314
kind="ipex::conv3d_sum_relu")
368315
self._test_output_bf16(
369-
CascadedConv3dBnSumRelu(3, 64, 32, kernel_size=3, stride=1),
316+
CascadedConvBnSumRelu(3, 3, 64, 32, kernel_size=3, stride=1),
370317
torch.rand(32, 3, 112, 112, 112),
371318
kind="ipex::conv3d_sum_relu",
372319
prec=0.02)

0 commit comments

Comments
 (0)