@@ -2265,24 +2265,6 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
 
 // -----
 
-func.func @torch.aten.avg_pool2d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
-  %int0 = torch.constant.int 0
-  %int1 = torch.constant.int 1
-  %int3 = torch.constant.int 3
-  %false = torch.constant.bool false
-  %count_include_pad = torch.constant.bool true
-  %divisor_override = torch.constant.none
-
-  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
-  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
-  return %3 : !torch.vtensor<[1,192,35,35],f32>
-}
-
-// -----
-
 func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
   %int0 = torch.constant.int 0
   %int1 = torch.constant.int 1
@@ -2802,21 +2784,6 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to
 
 // -----
 
-func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
-  %int1 = torch.constant.int 1
-  %int3 = torch.constant.int 3
-  %false = torch.constant.bool false
-  %count_include_pad = torch.constant.bool true
-  %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
-  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
-  // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}}
-  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
-  return %3 : !torch.vtensor<[1,512,10],f32>
-}
-
-// -----
-
 // CHECK-LABEL: func.func @torch.aten.reflection_pad1d$basic(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,2,4],f32> -> tensor<1x2x4xf32>
@@ -4328,3 +4295,82 @@ func.func @torch.aten.linear$f16(%arg0: !torch.vtensor<[2,4],f16>, %arg1: !torch.vtensor<[3,4],f16>, %arg2: !torch.vtensor<[3],f16>) -> !torch.vtensor<[2,3],f16> {
   %0 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[2,4],f16>, !torch.vtensor<[3,4],f16>, !torch.vtensor<[3],f16> -> !torch.vtensor<[2,3],f16>
   return %0 : !torch.vtensor<[2,3],f16>
 }
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.avg_pool2d.count_include_pad(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,192,35,35],f32> -> tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_6:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_7:.*]] = torch.constant.none
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_11:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 1, 1]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.pad %[[VAL_1]], %[[VAL_11]], %[[VAL_12]] : (tensor<1x192x35x35xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x192x37x37xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_13]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x192x37x37xf32>) -> tensor<1x37x37x192xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x37x37x192xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x35x35x192xf32>
+// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x35x35x192xf32>) -> tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_19:.*]] = tensor.cast %[[VAL_18]] : tensor<1x192x35x35xf32> to tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<1x192x35x35xf32> -> !torch.vtensor<[1,192,35,35],f32>
+// CHECK: return %[[VAL_20]] : !torch.vtensor<[1,192,35,35],f32>
+// CHECK: }
+func.func @torch.aten.avg_pool2d.count_include_pad(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %false = torch.constant.bool false
+  %count_include_pad = torch.constant.bool true
+  %divisor_override = torch.constant.none
+
+  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
+  return %3 : !torch.vtensor<[1,192,35,35],f32>
+}
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.avg_pool1d.count_include_pad(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_4:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_5:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = tosa.const_shape {values = dense<[1, 512, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_9]] : (tensor<1x512x10xf32>, !tosa.shape<4>) -> tensor<1x512x10x1xf32>
+// CHECK: %[[VAL_11:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.pad %[[VAL_10]], %[[VAL_11]], %[[VAL_12]] : (tensor<1x512x10x1xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x512x12x1xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_13]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x512x12x1xf32>) -> tensor<1x12x1x512xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x12x1x512xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x1x512xf32>
+// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x10x1x512xf32>) -> tensor<1x512x10x1xf32>
+// CHECK: %[[VAL_19:.*]] = tosa.const_shape {values = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]], %[[VAL_19]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32>
+// CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
+// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
+// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,512,10],f32>
+// CHECK: }
+func.func @torch.aten.avg_pool1d.count_include_pad(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %false = torch.constant.bool false
+  %count_include_pad = torch.constant.bool true
+  %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
+  return %3 : !torch.vtensor<[1,512,10],f32>
+}