@@ -1421,6 +1421,28 @@ func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !to
1421
1421
return %0 : !torch.vtensor <[1 ,12 ,5 ,5 ],f32 >
1422
1422
}
1423
1423
1424
// -----

// Verifies that aten.where.self with operands of differing ranks (2-D condition,
// 0-D "true" value, 6-D "false" value) is lowered to tosa.select with both
// lower-rank operands reshaped up to rank 6 so TOSA's implicit broadcasting applies.
// CHECK-LABEL: func.func @torch.aten.where.self_differing_rank_inputs(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,4],i1>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1,3,1,1,5,4],f32>) -> !torch.vtensor<[1,3,1,1,5,4],f32> {
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[1,3,1,1,5,4],f32> -> tensor<1x3x1x1x5x4xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,4],i1> -> tensor<5x4xi1>
// CHECK: %[[VAL_6:.*]] = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_6]] : (tensor<f32>, !tosa.shape<2>) -> tensor<1x1xf32>
// CHECK: %[[VAL_8:.*]] = tosa.const_shape {values = dense<[1, 1, 1, 1, 5, 4]> : tensor<6xindex>} : () -> !tosa.shape<6>
// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_8]] : (tensor<5x4xi1>, !tosa.shape<6>) -> tensor<1x1x1x1x5x4xi1>
// CHECK: %[[VAL_10:.*]] = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_10]] : (tensor<1x1xf32>, !tosa.shape<6>) -> tensor<1x1x1x1x1x1xf32>
// CHECK: %[[VAL_12:.*]] = tosa.select %[[VAL_9]], %[[VAL_11]], %[[VAL_3]] : (tensor<1x1x1x1x5x4xi1>, tensor<1x1x1x1x1x1xf32>, tensor<1x3x1x1x5x4xf32>) -> tensor<1x3x1x1x5x4xf32>
// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x3x1x1x5x4xf32> -> !torch.vtensor<[1,3,1,1,5,4],f32>
// CHECK: return %[[VAL_13]]
func.func @torch.aten.where.self_differing_rank_inputs(%40: !torch.vtensor<[5,4],i1>, %41: !torch.vtensor<[],f32>, %38: !torch.vtensor<[1,3,1,1,5,4],f32>) -> (!torch.vtensor<[1,3,1,1,5,4],f32>) {
  %42 = torch.aten.where.self %40, %41, %38 : !torch.vtensor<[5,4],i1>, !torch.vtensor<[],f32>, !torch.vtensor<[1,3,1,1,5,4],f32> -> !torch.vtensor<[1,3,1,1,5,4],f32>
  return %42 : !torch.vtensor<[1,3,1,1,5,4],f32>
}

1424
1446
// -----
1425
1447
// CHECK-LABEL: func.func @torch.aten.remainder.Scalar(
1426
1448
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> {
@@ -3780,6 +3802,91 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_inp
3780
3802
return %5 : !torch.vtensor <[1 ,32 ,75 ,75 ],f32 >
3781
3803
}
3782
3804
3805

// -----

// Verifies conv lowering with a dynamic batch dimension when the spatial dims
// (224) with stride 2 and pad 1 need no input slicing: only the leading pad
// values are kept (pad = 1, 0 per spatial dim) and the batch stays dynamic
// through tosa.conv2d; the final tensor.cast pins the expected result type.
// CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input_dynamic_batch(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,224,224],f32>) -> !torch.vtensor<[?,32,112,112],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,224,224],f32> -> tensor<?x3x224x224xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.bool false
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32>
// CHECK: %[[VAL_5:.*]] = torch.constant.none
// CHECK: %[[VAL_6:.*]] = torch.constant.int 2
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x224x224xf32>) -> tensor<?x224x224x3xf32>
// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_12]], %[[VAL_13]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 2, 2>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x112x112x32xf32>
// CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<?x112x112x32xf32>) -> tensor<?x32x112x112xf32>
// CHECK: %[[VAL_18:.*]] = tensor.cast %[[VAL_17]] : tensor<?x32x112x112xf32> to tensor<?x32x112x112xf32>
// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<?x32x112x112xf32> -> !torch.vtensor<[?,32,112,112],f32>
// CHECK: return %[[VAL_19]]

func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_input_dynamic_batch(%arg0: !torch.vtensor<[?,3,224,224],f32>) -> !torch.vtensor<[?,32,112,112],f32> {
  %false = torch.constant.bool false
  %int1 = torch.constant.int 1
  %0 = torch.vtensor.literal(dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>) : !torch.vtensor<[32,3,3,3],f32>
  %none = torch.constant.none
  %int2 = torch.constant.int 2
  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %4 = torch.prim.ListConstruct : () -> !torch.list<int>
  %5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int1 : !torch.vtensor<[?,3,224,224],f32>, !torch.vtensor<[32,3,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,32,112,112],f32>
  return %5 : !torch.vtensor<[?,32,112,112],f32>
}

3845
+
3846
// -----

// Verifies conv lowering with a dynamic batch dimension when the spatial dims
// (225) with stride 3 are indivisible and require trimming: the input is sliced
// down to 224 along each spatial dim (batch extent stays -1/dynamic in the
// tosa.const_shape sizes) before tosa.conv2d, and the batch remains dynamic
// through the output transpose and the final tensor.cast.
// CHECK-LABEL: func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input_dynamic_batch(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,3,225,225],f32>) -> !torch.vtensor<[?,32,75,75],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,225,225],f32> -> tensor<?x3x225x225xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.bool false
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>}> : () -> tensor<32x3x3x3xf32>
// CHECK: %[[VAL_5:.*]] = torch.constant.none
// CHECK: %[[VAL_6:.*]] = torch.constant.int 3
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x225x225xf32>) -> tensor<?x225x225x3xf32>
// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
// CHECK-DAG: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK-DAG: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[-1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_12]], %[[VAL_14]], %[[VAL_15]] : (tensor<?x225x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x225x3xf32>
// CHECK-DAG: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK-DAG: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[-1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_16]], %[[VAL_17]], %[[VAL_18]] : (tensor<?x224x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x224x3xf32>
// CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_13]], %[[VAL_11]], %[[VAL_20]], %[[VAL_21]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 3, 3>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x75x75x32xf32>
// CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_22]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<?x75x75x32xf32>) -> tensor<?x32x75x75xf32>
// CHECK: %[[VAL_24:.*]] = tensor.cast %[[VAL_23]] : tensor<?x32x75x75xf32> to tensor<?x32x75x75xf32>
// CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor<?x32x75x75xf32> -> !torch.vtensor<[?,32,75,75],f32>
// CHECK: return %[[VAL_25]]
func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_input_dynamic_batch(%arg0: !torch.vtensor<[?,3,225,225],f32>) -> !torch.vtensor<[?,32,75,75],f32> {
  %false = torch.constant.bool false
  %int1 = torch.constant.int 1
  %0 = torch.vtensor.literal(dense_resource<torch_tensor_32_3_3_3_torch.float32> : tensor<32x3x3x3xf32>) : !torch.vtensor<[32,3,3,3],f32>
  %none = torch.constant.none
  %int3 = torch.constant.int 3
  %1 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %4 = torch.prim.ListConstruct : () -> !torch.list<int>
  %5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int1 : !torch.vtensor<[?,3,225,225],f32>, !torch.vtensor<[32,3,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,32,75,75],f32>
  return %5 : !torch.vtensor<[?,32,75,75],f32>
}

3783
3890
// -----
3784
3891
3785
3892
// CHECK-LABEL: func.func @torch.aten.max_pool2d$zero_pad_with_sliced_input(
0 commit comments