@@ -493,6 +493,32 @@ util.func public @conv_1d_nhc_chf(%arg0: tensor<1x3x2xf32>, %arg1: tensor<2x2x2x
493493
494494// -----
495495
496+ #map3 = affine_map <(d0 , d1 , d2 , d3 , d4 , d5 , d6 ) -> (d2 + d5 , d3 + d6 , d0 , d4 )>
497+ #map4 = affine_map <(d0 , d1 , d2 , d3 , d4 , d5 , d6 ) -> (d5 , d6 , d0 , d1 )>
498+ #map5 = affine_map <(d0 , d1 , d2 , d3 , d4 , d5 , d6 ) -> (d0 , d1 , d2 , d3 , d4 )>
499+ util.func public @conv_2d_no_input_channel (%arg0: tensor <61 x93 x16 x64 xbf16 >, %arg1: tensor <59 x91 x16 x56 xbf16 >, %arg2: tensor <16 x56 x3 x3 x64 xf32 >) -> tensor <16 x56 x3 x3 x64 xf32 > {
500+ %0 = linalg.generic {index ing_maps = [#map3 , #map4 , #map5 ], iterator_types = [" parallel" , " parallel" , " parallel" , " parallel" , " parallel" , " reduction" , " reduction" ]} ins (%arg0 , %arg1 : tensor <61 x93 x16 x64 xbf16 >, tensor <59 x91 x16 x56 xbf16 >) outs (%arg2 : tensor <16 x56 x3 x3 x64 xf32 >) {
501+ ^bb0 (%in: bf16 , %in_0: bf16 , %out: f32 ):
502+ %1 = arith.extf %in : bf16 to f32
503+ %2 = arith.extf %in_0 : bf16 to f32
504+ %3 = arith.mulf %1 , %2 : f32
505+ %4 = arith.addf %out , %3 : f32
506+ linalg.yield %4 : f32
507+ } -> tensor <16 x56 x3 x3 x64 xf32 >
508+ util.return %0 : tensor <16 x56 x3 x3 x64 xf32 >
509+ }
510+
511+ // CHECK: util.func public @conv_2d_no_input_channel(
512+ // CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
513+ // CHECK-SAME: strides = [1, 1] dilations = [1, 1] kernel_size = [59, 91]
514+ // CHECK-SAME: m_offset = [0, 0] * [3, 1] k_offset = [0] * [1]
515+ // CHECK-SAME: batch_pos = [3, 2] m_pos = [0, 1] k_pos = []
516+ // CHECK-SAME: input_k_perm = [0, 1] output_perm = [2, 3, 4, 1, 0]
517+ // CHECK-SAME: ins({{.*}} : tensor<61x93x16x64xbf16>)
518+ // CHECK-SAME: outs({{.*}} : tensor<3x3x5369x16x64xbf16>) -> tensor<3x3x5369x16x64xbf16>
519+
520+ // -----
521+
496522util.func public @conv_2d_nhwgc_gfhwc (%arg0: tensor <2 x10 x10 x7 x4 xf32 >, %arg1: tensor <7 x16 x3 x3 x4 xf32 >, %arg2: tensor <2 x8 x8 x7 x16 xf32 >) -> tensor <2 x8 x8 x7 x16 xf32 > {
497523 %0 = linalg.conv_2d_nhwgc_gfhwc
498524 {dilations = dense <1 > : tensor <2 xi64 >, strides = dense <1 > : tensor <2 xi64 > }
0 commit comments