@@ -69,6 +69,14 @@ func.func @conv_2d_nhwc_hwcf_CDCC(%arg0: tensor<?x?x?x?xf32, #CDCC>, %arg1: tens
6969 return %ret : tensor <?x?x?x?xf32 >
7070}
7171
// 2-D NHWC/HWCF convolution in which both the input (%arg0) and the filter
// (%arg1) carry the #CDCC encoding, while the output operand (%arg2) and the
// result are unannotated tensors. Dilations are 1 and strides are 2.
// NOTE(review): #CDCC is declared outside this excerpt; confirm its level
// format against that definition.
func.func @conv_2d_nhwc_hwcf_dual_CDCC(
    %arg0: tensor<?x?x?x?xf32, #CDCC>,
    %arg1: tensor<?x?x?x?xf32, #CDCC>,
    %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  %conv = linalg.conv_2d_nhwc_hwcf
    {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
    ins(%arg0, %arg1 : tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32, #CDCC>)
    outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  return %conv : tensor<?x?x?x?xf32>
}
79+
7280
7381func.func @entry () {
7482 %c0 = arith.constant 0 : index
@@ -87,16 +95,28 @@ func.func @entry() {
8795
8896 %in2D_nhwc_CCCC = sparse_tensor.convert %in2D_nhwc
8997 : tensor <?x?x?x?xf32 > to tensor <?x?x?x?xf32 , #CCCC >
98+ %filter2D_nhwc_CDCC = sparse_tensor.convert %filter2D_nhwc
99+ : tensor <?x?x?x?xf32 > to tensor <?x?x?x?xf32 , #CDCC >
90100 %in2D_nhwc_CDCC = sparse_tensor.convert %in2D_nhwc
91101 : tensor <?x?x?x?xf32 > to tensor <?x?x?x?xf32 , #CDCC >
92102
93103 %dense_ret = call @conv_2d_nhwc_hwcf (%in2D_nhwc , %filter2D_nhwc , %out2D_nhwc ) : (tensor <?x?x?x?xf32 >, tensor <?x?x?x?xf32 >, tensor <?x?x?x?xf32 >) -> (tensor <?x?x?x?xf32 >)
94104 %CCCC_ret = call @conv_2d_nhwc_hwcf_CCCC (%in2D_nhwc_CCCC , %filter2D_nhwc , %out2D_nhwc ) : (tensor <?x?x?x?xf32 , #CCCC >, tensor <?x?x?x?xf32 >, tensor <?x?x?x?xf32 >) -> (tensor <?x?x?x?xf32 >)
95105 %CDCC_ret = call @conv_2d_nhwc_hwcf_CDCC (%in2D_nhwc_CDCC , %filter2D_nhwc , %out2D_nhwc ) : (tensor <?x?x?x?xf32 , #CDCC >, tensor <?x?x?x?xf32 >, tensor <?x?x?x?xf32 >) -> (tensor <?x?x?x?xf32 >)
96106
107+ %dual_CDCC_ret = call @conv_2d_nhwc_hwcf_dual_CDCC (%in2D_nhwc_CDCC , %filter2D_nhwc_CDCC , %out2D_nhwc )
108+ : (tensor <?x?x?x?xf32 , #CDCC >, tensor <?x?x?x?xf32 , #CDCC >, tensor <?x?x?x?xf32 >) -> (tensor <?x?x?x?xf32 >)
109+
97110 // CHECK: ( ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 20 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
98111 // CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
99112 // CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ) )
113+ %v_dual = vector.transfer_read %dual_CDCC_ret [%c0 , %c0 , %c0 , %c0 ], %zero
114+ : tensor <?x?x?x?xf32 >, vector <3 x3 x3 x1 xf32 >
115+ vector.print %v_dual : vector <3 x3 x3 x1 xf32 >
116+
117+ // CHECK-NEXT: ( ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 20 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
118+ // CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
119+ // CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ) )
100120 %dense_v = vector.transfer_read %dense_ret [%c0 , %c0 , %c0 , %c0 ], %zero
101121 : tensor <?x?x?x?xf32 >, vector <3 x3 x3 x1 xf32 >
102122 vector.print %dense_v : vector <3 x3 x3 x1 xf32 >
@@ -120,6 +140,7 @@ func.func @entry() {
120140 bufferization.dealloc_tensor %filter2D_nhwc : tensor <?x?x?x?xf32 >
121141 bufferization.dealloc_tensor %out2D_nhwc : tensor <?x?x?x?xf32 >
122142
143+ bufferization.dealloc_tensor %filter2D_nhwc_CDCC : tensor <?x?x?x?xf32 , #CDCC >
123144 bufferization.dealloc_tensor %in2D_nhwc_CCCC : tensor <?x?x?x?xf32 , #CCCC >
124145 bufferization.dealloc_tensor %in2D_nhwc_CDCC : tensor <?x?x?x?xf32 , #CDCC >
125146 return
0 commit comments