@@ -305,6 +305,86 @@ func.func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
 }

 // -----
+
+// CHECK: func @fold_linalg_index_tensor_static
+func.func @fold_linalg_index_tensor_static(%0: tensor<4x16xi32>, %1: tensor<1x16xi32>,
+                                           %2: tensor<4x1xi32>) -> tensor<4x1xi32> {
+  // CHECK-NEXT: linalg.generic
+  // CHECK: %[[IDX_0:.+]] = linalg.index 0 : index
+  // CHECK-NOT: linalg.index 1
+  // CHECK: %[[IDX_2:.+]] = linalg.index 2 : index
+  // CHECK: %[[ADD:.+]] = arith.addi %[[IDX_0]], %[[IDX_2]]
+  // CHECK: %[[CAST:.+]] = arith.index_cast %[[ADD]]
+  // CHECK: linalg.yield %[[CAST]]
+  %res = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                          affine_map<(d0, d1, d2) -> (d1, d2)>,
+                                          affine_map<(d0, d1, d2) -> (d0, d1)>],
+                         iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%0, %1 : tensor<4x16xi32>, tensor<1x16xi32>)
+    outs(%2 : tensor<4x1xi32>) {
+  ^bb0(%lhs: i32, %rhs: i32, %out: i32):
+    %idx0 = linalg.index 0 : index
+    %idx1 = linalg.index 1 : index
+    %idx2 = linalg.index 2 : index
+    %add0 = arith.addi %idx0, %idx1 : index
+    %add1 = arith.addi %add0, %idx2 : index
+    %int = arith.index_cast %add1 : index to i32
+    linalg.yield %int : i32
+  } -> tensor<4x1xi32>
+  return %res : tensor<4x1xi32>
+}
+
+// -----
+
+// CHECK: func @fold_linalg_index_tensor_dynamic
+func.func @fold_linalg_index_tensor_dynamic(%0: tensor<?x1xi32>,
+                                            %1: tensor<?x1xi32>) -> tensor<?x1xi32> {
+  // CHECK-NEXT: linalg.generic
+  // CHECK: %[[IDX_0:.+]] = linalg.index 0 : index
+  // CHECK-NOT: linalg.index 1
+  // CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX_0]]
+  // CHECK: linalg.yield %[[CAST]]
+  %res = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d1, d1)>],
+                         iterator_types = ["parallel", "parallel"]}
+    ins(%0 : tensor<?x1xi32>)
+    outs(%1 : tensor<?x1xi32>) {
+  ^bb0(%lhs: i32, %out: i32):
+    %idx0 = linalg.index 0 : index
+    %idx1 = linalg.index 1 : index
+    %add = arith.addi %idx0, %idx1 : index
+    %int = arith.index_cast %add : index to i32
+    linalg.yield %int : i32
+  } -> tensor<?x1xi32>
+  return %res : tensor<?x1xi32>
+}
+
+// -----
+
+// CHECK: func @fold_linalg_index_memref
+func.func @fold_linalg_index_memref(%0: memref<1x?xi32>, %1: memref<1x?xi32>) {
+  // CHECK-NEXT: linalg.generic
+  // CHECK-NOT: linalg.index 0
+  // CHECK: %[[IDX_1:.+]] = linalg.index 1 : index
+  // CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX_1]]
+  // CHECK: linalg.yield %[[CAST]]
+  linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                   affine_map<(d0, d1) -> (d1, d1)>],
+                  iterator_types = ["parallel", "parallel"]}
+    ins(%0 : memref<1x?xi32>)
+    outs(%1 : memref<1x?xi32>) {
+  ^bb0(%lhs: i32, %out: i32):
+    %idx0 = linalg.index 0 : index
+    %idx1 = linalg.index 1 : index
+    %add = arith.addi %idx0, %idx1 : index
+    %int = arith.index_cast %add : index to i32
+    linalg.yield %int : i32
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_fill_reshape()
 func.func @fold_fill_reshape() -> tensor<6x4xf32> {
   %zero = arith.constant 0.0 : f32
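
All three tests exercise the same canonicalization from the output side: a `linalg.index` over a loop dimension whose static extent is 1 can only ever yield 0, so it folds to a constant, and the `arith.addi` that consumed it is then cleaned up by arith's existing folders. The hunk does not show the implementation itself; the C++ below is a minimal sketch of a rewrite with that effect, assuming MLIR's `LinalgOp` interface and its `getStaticLoopRanges()` helper (dynamic extents are reported as `ShapedType::kDynamic`, so `?` dims never match). `FoldUnitDimIndex` is a hypothetical name, not the pattern added by this commit.

    #include "mlir/Dialect/Arith/IR/Arith.h"
    #include "mlir/Dialect/Linalg/IR/Linalg.h"
    #include "mlir/IR/PatternMatch.h"

    using namespace mlir;

    // Sketch: replace linalg.index over a unit-extent loop with constant 0.
    struct FoldUnitDimIndex : OpRewritePattern<linalg::IndexOp> {
      using OpRewritePattern::OpRewritePattern;

      LogicalResult matchAndRewrite(linalg::IndexOp indexOp,
                                    PatternRewriter &rewriter) const override {
        // linalg.index must live inside a structured op; bail out otherwise.
        auto linalgOp = dyn_cast<linalg::LinalgOp>(indexOp->getParentOp());
        if (!linalgOp)
          return failure();
        // Static extent of each loop in the iteration space; dynamic extents
        // come back as ShapedType::kDynamic and are left untouched.
        auto loopRanges = linalgOp.getStaticLoopRanges();
        uint64_t dim = indexOp.getDim();
        if (dim >= loopRanges.size() || loopRanges[dim] != 1)
          return failure();
        // A loop with trip count 1 only ever visits index 0.
        rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
        return success();
      }
    };

Once the unit-dim index becomes a constant, e.g. in @fold_linalg_index_tensor_static, `arith.addi %idx0, %c0` folds to `%idx0`, which is why the CHECK lines expect a single `arith.addi` of `linalg.index 0` and `linalg.index 2` and no trace of `linalg.index 1`.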