|
1 | 1 | // RUN: triton-opt %s --convert-triton-intel-gpu-to-llvm | FileCheck %s |
2 | 2 |
|
| 3 | +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [4], order = [0]}> |
| 4 | + |
3 | 5 | // COM: check that the spirv target env is inserted |
4 | 6 | // CHECK: module attributes {{{.*}}spirv.target_env{{.*}}#spirv.resource_limits<subgroup_size = 16> |
5 | | -module attributes { "ttg.threads-per-warp" = 16 : i32, "ttg.num-warps" = 4 : i32 } { } |
| 7 | +module attributes { "ttg.threads-per-warp" = 16 : i32, "ttg.num-warps" = 4 : i32 } { |
| 8 | + // As the assert message is shared, a single instance is emitted. |
| 9 | + |
| 10 | + // CHECK-DAG: llvm.mlir.global internal constant @assertFunc_("unknown\00") {addr_space = 1 : i32} |
| 11 | + // CHECK-DAG: llvm.mlir.global internal constant @assertFile_("{{.*}}/test/Conversion/intel/tritonintelgpu_to_llvm.mlir\00") {addr_space = 1 : i32} |
| 12 | + // CHECK-DAG: llvm.mlir.global internal constant @assertMessage_("assert text\00") {addr_space = 1 : i32} |
| 13 | + // CHECK-DAG: llvm.mlir.global internal constant @assertMessage_3("different assert text\00") {addr_space = 1 : i32} |
| 14 | + // CHECK-DAG: llvm.func spir_funccc @__assert_fail(!llvm.ptr<4>, !llvm.ptr<4>, i32, !llvm.ptr<4>) |
| 15 | + |
| 16 | + // CHECK: llvm.func spir_kernelcc @assert(%[[VAL_0:.*]]: !llvm.struct<(i1)>, %[[VAL_1:.*]]: !llvm.struct<(i1)>, %[[VAL_2:.*]]: !llvm.struct<(i1)>) |
| 17 | + tt.func public @assert(%arg0: tensor<1xi1, #blocked>, %arg1: tensor<1xi1, #blocked>, %arg2: tensor<1xi1, #blocked>) { |
| 18 | + // CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(i1)> |
| 19 | + // CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(false) : i1 |
| 20 | + // CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(false) : i1 |
| 21 | + // CHECK: %[[VAL_6:.*]] = llvm.icmp "eq" %[[VAL_3]], %[[VAL_5]] : i1 |
| 22 | + // CHECK: %[[VAL_7:.*]] = llvm.or %[[VAL_4]], %[[VAL_6]] : i1 |
| 23 | + // CHECK: llvm.cond_br %[[VAL_7]], ^bb1, ^bb2 |
| 24 | + // CHECK: ^bb1: |
| 25 | + // CHECK: %[[VAL_8:.*]] = llvm.mlir.addressof @assertMessage_ : !llvm.ptr<1> |
| 26 | + // CHECK: %[[VAL_9:.*]] = llvm.getelementptr %[[VAL_8]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i8 |
| 27 | + // CHECK: %[[VAL_10:.*]] = llvm.mlir.addressof @assertFile_ : !llvm.ptr<1> |
| 28 | + // CHECK: %[[VAL_11:.*]] = llvm.getelementptr %[[VAL_10]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i8 |
| 29 | + // CHECK: %[[VAL_12:.*]] = llvm.mlir.addressof @assertFunc_ : !llvm.ptr<1> |
| 30 | + // CHECK: %[[VAL_13:.*]] = llvm.getelementptr %[[VAL_12]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i8 |
| 31 | + // CHECK: %[[VAL_14:.*]] = llvm.mlir.constant |
| 32 | + // CHECK: %[[VAL_15:.*]] = llvm.addrspacecast %[[VAL_9]] : !llvm.ptr<1> to !llvm.ptr<4> |
| 33 | + // CHECK: %[[VAL_16:.*]] = llvm.addrspacecast %[[VAL_11]] : !llvm.ptr<1> to !llvm.ptr<4> |
| 34 | + // CHECK: %[[VAL_17:.*]] = llvm.addrspacecast %[[VAL_13]] : !llvm.ptr<1> to !llvm.ptr<4> |
| 35 | + // CHECK: llvm.call spir_funccc @__assert_fail(%[[VAL_15]], %[[VAL_16]], %[[VAL_14]], %[[VAL_17]]) : (!llvm.ptr<4>, !llvm.ptr<4>, i32, !llvm.ptr<4>) -> () |
| 36 | + // CHECK: llvm.br ^bb2 |
| 37 | + // CHECK: ^bb2: |
| 38 | + // CHECK: %[[VAL_18:.*]] = llvm.mlir.constant(1 : i32) : i32 |
| 39 | + // CHECK: llvm.call spir_funccc @_Z7barrierj(%[[VAL_18]]) {convergent, no_unwind, will_return} : (i32) -> () |
| 40 | + tt.assert %arg0, "assert text" : tensor<1xi1, #blocked> |
| 41 | + // CHECK: %[[VAL_19:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.struct<(i1)> |
| 42 | + // CHECK: %[[VAL_20:.*]] = llvm.mlir.constant(false) : i1 |
| 43 | + // CHECK: %[[VAL_21:.*]] = llvm.mlir.constant(false) : i1 |
| 44 | + // CHECK: %[[VAL_22:.*]] = llvm.icmp "eq" %[[VAL_19]], %[[VAL_21]] : i1 |
| 45 | + // CHECK: %[[VAL_23:.*]] = llvm.or %[[VAL_20]], %[[VAL_22]] : i1 |
| 46 | + // CHECK: llvm.cond_br %[[VAL_23]], ^bb3, ^bb4 |
| 47 | + // CHECK: ^bb3: |
| 48 | + // CHECK: %[[VAL_24:.*]] = llvm.mlir.addressof @assertMessage_ : !llvm.ptr<1> |
| 49 | + // CHECK: %[[VAL_25:.*]] = llvm.getelementptr %[[VAL_24]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i8 |
| 50 | + // CHECK: %[[VAL_26:.*]] = llvm.mlir.addressof @assertFile_ : !llvm.ptr<1> |
| 51 | + // CHECK: %[[VAL_27:.*]] = llvm.getelementptr %[[VAL_26]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i8 |
| 52 | + // CHECK: %[[VAL_28:.*]] = llvm.mlir.addressof @assertFunc_ : !llvm.ptr<1> |
| 53 | + // CHECK: %[[VAL_29:.*]] = llvm.getelementptr %[[VAL_28]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i8 |
| 54 | + // CHECK: %[[VAL_30:.*]] = llvm.mlir.constant |
| 55 | + // CHECK: %[[VAL_31:.*]] = llvm.addrspacecast %[[VAL_25]] : !llvm.ptr<1> to !llvm.ptr<4> |
| 56 | + // CHECK: %[[VAL_32:.*]] = llvm.addrspacecast %[[VAL_27]] : !llvm.ptr<1> to !llvm.ptr<4> |
| 57 | + // CHECK: %[[VAL_33:.*]] = llvm.addrspacecast %[[VAL_29]] : !llvm.ptr<1> to !llvm.ptr<4> |
| 58 | + // CHECK: llvm.call spir_funccc @__assert_fail(%[[VAL_31]], %[[VAL_32]], %[[VAL_30]], %[[VAL_33]]) : (!llvm.ptr<4>, !llvm.ptr<4>, i32, !llvm.ptr<4>) -> () |
| 59 | + // CHECK: llvm.br ^bb4 |
| 60 | + // CHECK: ^bb4: |
| 61 | + // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(1 : i32) : i32 |
| 62 | + // CHECK: llvm.call spir_funccc @_Z7barrierj(%[[VAL_34]]) {convergent, no_unwind, will_return} : (i32) -> () |
| 63 | + tt.assert %arg1, "assert text" : tensor<1xi1, #blocked> |
| 64 | + // CHECK: %[[VAL_35:.*]] = llvm.extractvalue %[[VAL_2]][0] : !llvm.struct<(i1)> |
| 65 | + // CHECK: %[[VAL_36:.*]] = llvm.mlir.constant(false) : i1 |
| 66 | + // CHECK: %[[VAL_37:.*]] = llvm.mlir.constant(false) : i1 |
| 67 | + // CHECK: %[[VAL_38:.*]] = llvm.icmp "eq" %[[VAL_35]], %[[VAL_37]] : i1 |
| 68 | + // CHECK: %[[VAL_39:.*]] = llvm.or %[[VAL_36]], %[[VAL_38]] : i1 |
| 69 | + // CHECK: llvm.cond_br %[[VAL_39]], ^bb5, ^bb6 |
| 70 | + // CHECK: ^bb5: |
| 71 | + // CHECK: %[[VAL_40:.*]] = llvm.mlir.addressof @assertMessage_3 : !llvm.ptr<1> |
| 72 | + // CHECK: %[[VAL_41:.*]] = llvm.getelementptr %[[VAL_40]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i8 |
| 73 | + // CHECK: %[[VAL_42:.*]] = llvm.mlir.addressof @assertFile_ : !llvm.ptr<1> |
| 74 | + // CHECK: %[[VAL_43:.*]] = llvm.getelementptr %[[VAL_42]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i8 |
| 75 | + // CHECK: %[[VAL_44:.*]] = llvm.mlir.addressof @assertFunc_ : !llvm.ptr<1> |
| 76 | + // CHECK: %[[VAL_45:.*]] = llvm.getelementptr %[[VAL_44]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i8 |
| 77 | + // CHECK: %[[VAL_46:.*]] = llvm.mlir.constant |
| 78 | + // CHECK: %[[VAL_47:.*]] = llvm.addrspacecast %[[VAL_41]] : !llvm.ptr<1> to !llvm.ptr<4> |
| 79 | + // CHECK: %[[VAL_48:.*]] = llvm.addrspacecast %[[VAL_43]] : !llvm.ptr<1> to !llvm.ptr<4> |
| 80 | + // CHECK: %[[VAL_49:.*]] = llvm.addrspacecast %[[VAL_45]] : !llvm.ptr<1> to !llvm.ptr<4> |
| 81 | + // CHECK: llvm.call spir_funccc @__assert_fail(%[[VAL_47]], %[[VAL_48]], %[[VAL_46]], %[[VAL_49]]) : (!llvm.ptr<4>, !llvm.ptr<4>, i32, !llvm.ptr<4>) -> () |
| 82 | + // CHECK: llvm.br ^bb6 |
| 83 | + // CHECK: ^bb6: |
| 84 | + // CHECK: %[[VAL_50:.*]] = llvm.mlir.constant(1 : i32) : i32 |
| 85 | + // CHECK: llvm.call spir_funccc @_Z7barrierj(%[[VAL_50]]) {convergent, no_unwind, will_return} : (i32) -> () |
| 86 | + tt.assert %arg2, "different assert text" : tensor<1xi1, #blocked> |
| 87 | + tt.return |
| 88 | + } |
| 89 | +} |
0 commit comments