// RUN: triton-shared-opt --triton-to-linalg-experimental %s | FileCheck %s
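// The kernel below loads a single i32 through a 1-element splat pointer
// tensor, compares it against 42, and collapses the resulting tensor<1xi1>
// back to a scalar i1 with tt.reduce (the "unsplat" pattern) to drive an
// scf.if that conditionally stores 42 back through the raw pointer.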
module {
  tt.func public @unsplat_kernel(%arg0: !tt.ptr<i32> {maia.rank = 1 : i32, tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<42> : tensor<1xi32>
    %c42_i32 = arith.constant 42 : i32
    %0 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1x!tt.ptr<i32>>
    %1 = tt.load %0 : tensor<1x!tt.ptr<i32>>
    %2 = arith.cmpi sgt, %1, %cst : tensor<1xi32>
    %3 = "tt.reduce"(%2) <{axis = 0 : i32}> ({
    ^bb0(%arg1: i1, %arg2: i1):
      tt.reduce.return %arg1 : i1
    }) : (tensor<1xi1>) -> i1
    scf.if %3 {
      tt.store %arg0, %c42_i32 : !tt.ptr<i32>
    }
    tt.return
  }
}

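// The checks below expect the size-1 tt.reduce to vanish in the lowering: the
// comparison result is materialized as a tensor<1xi1> by a linalg.generic, and
// the scf.if condition is obtained with a plain tensor.extract at index 0
// rather than a linalg.reduce.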
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @unsplat_kernel
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xi32> {maia.rank = 1 : i32, tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) {
// CHECK-DAG: [[CST_42_:%.+]] = arith.constant 42 : i32
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<1xi32>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_42_]] : i32) outs([[VAR_0_]] : tensor<1xi32>) -> tensor<1xi32>
// CHECK-DAG: [[VAR_2_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_0_]] : tensor<1xi32>) -> tensor<1xi32>
// CHECK-DAG: [[VAR_cast_:%.+]] = memref.cast [[PARAM_0_]] : memref<*xi32> to memref<?xi32>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_3_:%.+]] = bufferization.to_tensor [[VAR_cast_]] restrict : memref<?xi32> to tensor<?xi32>
// CHECK-DAG: [[VAR_4_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_2_]] : tensor<1xi32>) outs([[VAR_0_]] : tensor<1xi32>) {
// CHECK: ^bb0([[IN_0_:%.+]]: i32, [[IN_1_:%.+]]: i32):
// CHECK: [[VAR_7_:%.+]] = arith.index_cast [[IN_0_]] : i32 to index
// CHECK: [[VAR_extracted_0_:%.+]] = tensor.extract [[VAR_3_]]{{.}}[[VAR_7_]]{{.}} : tensor<?xi32>
// CHECK: linalg.yield [[VAR_extracted_0_]] : i32
// CHECK: } -> tensor<1xi32>
// CHECK: [[VAR_5_:%.+]] = tensor.empty() : tensor<1xi1>
// CHECK: [[VAR_6_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_4_]], [[VAR_1_]] : tensor<1xi32>, tensor<1xi32>) outs([[VAR_5_]] : tensor<1xi1>) {
// CHECK: ^bb0([[IN_2_:%.+]]: i32, [[IN_3_:%.+]]: i32, [[IN_4_:%.+]]: i1):
// CHECK: [[VAR_7_1_:%.+]] = arith.cmpi sgt, [[IN_2_]], [[IN_3_]] : i32
// CHECK: linalg.yield [[VAR_7_1_]] : i1
// CHECK: } -> tensor<1xi1>
// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_6_]]{{.}}[[CST_0_1_]]{{.}} : tensor<1xi1>
// CHECK: scf.if [[VAR_extracted_]] {
// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[CST_0_1_]]{{.}}, sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>>
// CHECK: affine.store [[CST_42_]], [[VAR_reinterpret_cast_]][0] : memref<1xi32, strided<[1], offset: ?>>
// CHECK: }
// CHECK: return
// CHECK: }