// RUN: triton-opt %s -triton-intel-tdesc-to-block-pointer | FileCheck %s

module {
  // COM: Loop containing a tensor descriptor load operation using a loop-invariant tensor
  // COM: descriptor: the descriptor is created outside the loop and only passed through the
  // COM: yield unchanged, so the pass can rewrite the in-loop descriptor_load into a
  // COM: make_tensor_ptr + tt.load pair (the make_tensor_ptr is materialized inside the loop).
  tt.func public @load_in_loop1(%arg0: !tt.ptr<f16>, %arg1: i32, %arg2: i32) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c10 = arith.constant 10 : index
    %c1_i64 = arith.constant 1 : i64
    %c8_i32 = arith.constant 8 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<16x32xf16>
    %0 = arith.extsi %arg2 : i32 to i64
    %tdesc = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <f16>, <tensor<16x32xf16>>
    %tdesc_out, %sum_out = scf.for %i = %c0 to %c10 step %c1 iter_args(%ptr_iter = %tdesc, %sum_iter = %cst) -> (!tt.tensordesc<tensor<16x32xf16>>, tensor<16x32xf16>) {
      %cast_i = arith.index_cast %i : index to i32
      %load1 = tt.descriptor_load %ptr_iter[%c8_i32, %cast_i] : !tt.tensordesc<tensor<16x32xf16>> -> tensor<16x32xf16>
      %sum_next = arith.addf %sum_iter, %load1 : tensor<16x32xf16>
      scf.yield %ptr_iter, %sum_next : !tt.tensordesc<tensor<16x32xf16>>, tensor<16x32xf16>
    }
    tt.return
  }
  // CHECK: tt.func public @load_in_loop1([[PARAM_0:%.+]]: !tt.ptr<f16>, [[PARAM_1:%.+]]: i32, [[PARAM_2:%.+]]: i32) {
  // CHECK-NOT: tt.make_tensor_descriptor
  // CHECK-NOT: tt.descriptor_load
  // CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
  // CHECK-DAG: [[CST_8_i32:%.+]] = arith.constant 8 : i32
  // CHECK-DAG: [[CST:%.+]] = arith.constant dense<0.000000e+00> : tensor<16x32xf16>
  // CHECK-DAG: [[EXTSI_PARAM_2a:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
  // CHECK: [[FOR_RES:%.+]]:2 = scf.for [[IV:%.+]] = {{.*}} iter_args([[VAR_arg1:%.+]] = {{.*}}, [[VAR_arg2:%.+]] = [[CST]]) -> (!tt.tensordesc<tensor<16x32xf16>>, tensor<16x32xf16>) {
  // CHECK-DAG: [[IDX_CAST_1:%.+]] = arith.index_cast [[IV]] : index to i32
  // CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64
  // CHECK-DAG: [[EXTSI_PARAM_2b:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
  // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2b]]], {{\[}}[[EXTSI_PARAM_2a]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[IDX_CAST_1]]] {{.*}} : <tensor<16x32xf16>>
  // CHECK: [[LOAD:%.+]] = tt.load [[TENSOR_PTR]] : !tt.ptr<tensor<16x32xf16>>
  // CHECK: [[ADD:%.+]] = arith.addf [[VAR_arg2]], [[LOAD]] : tensor<16x32xf16>
  // CHECK: scf.yield {{.*}}, [[ADD]] : !tt.tensordesc<tensor<16x32xf16>>, tensor<16x32xf16>
  // CHECK: }
  // CHECK: tt.return
  // CHECK: }

  // COM: Loop containing a tensor descriptor load operation using a loop-variant tensor
  // COM: descriptor: a second descriptor is created inside the loop and selected into the
  // COM: iter_arg, so the descriptor reaching the load is not uniquely determined and the
  // COM: pass must leave make_tensor_descriptor / descriptor_load untouched.
  tt.func public @load_in_loop2(%arg0: !tt.ptr<f16>, %arg1: i32, %arg2: i32) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c10 = arith.constant 10 : index
    %c1_i64 = arith.constant 1 : i64
    %c8_i32 = arith.constant 8 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<16x32xf16>
    %0 = arith.extsi %arg2 : i32 to i64
    %tdesc = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <f16>, <tensor<16x32xf16>>
    %tdesc_out, %sum_out = scf.for %i = %c0 to %c10 step %c1 iter_args(%ptr_iter = %tdesc, %sum_iter = %cst) -> (!tt.tensordesc<tensor<16x32xf16>>, tensor<16x32xf16>) {
      %cast_i = arith.index_cast %i : index to i32
      %load1 = tt.descriptor_load %ptr_iter[%c8_i32, %cast_i] : !tt.tensordesc<tensor<16x32xf16>> -> tensor<16x32xf16>
      %sum_next = arith.addf %sum_iter, %load1 : tensor<16x32xf16>
      %tdesc_in_loop = tt.make_tensor_descriptor %arg0, [%arg2, %arg1], [%c1_i64, %0] : <f16>, <tensor<16x32xf16>>
      %cmp = arith.cmpi eq, %cast_i, %c8_i32 : i32
      %sel_tdesc = arith.select %cmp, %ptr_iter, %tdesc_in_loop : !tt.tensordesc<tensor<16x32xf16>>
      scf.yield %sel_tdesc, %sum_next : !tt.tensordesc<tensor<16x32xf16>>, tensor<16x32xf16>
    }
    tt.return
  }
  // CHECK: tt.func public @load_in_loop2({{.*}}) {
  // CHECK-NOT: tt.make_tensor_ptr
  // CHECK-NOT: tt.load
  // CHECK: tt.make_tensor_descriptor
  // CHECK: [[FOR_RES:%.+]]:2 = scf.for [[IV:%.+]] = {{.*}} -> (!tt.tensordesc<tensor<16x32xf16>>, tensor<16x32xf16>) {
  // CHECK: tt.descriptor_load
  // CHECK: tt.make_tensor_descriptor
  // CHECK: }
  // CHECK: tt.return
  // CHECK: }

  // COM: Loop yields a tensor descriptor used by a tensor descriptor load after the loop.
  // COM: The load's descriptor operand is a loop result rather than the make_tensor_descriptor
  // COM: itself, so the pass must not rewrite the load.
  tt.func public @load_uses_loop_result(%arg0: !tt.ptr<f16>, %arg1: i32, %arg2: i32) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c10 = arith.constant 10 : index
    %c1_i64 = arith.constant 1 : i64
    %c8_i32 = arith.constant 8 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<16x32xf16>
    %0 = arith.extsi %arg2 : i32 to i64
    %tdesc = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <f16>, <tensor<16x32xf16>>
    %tdesc_out = scf.for %i = %c0 to %c10 step %c1 iter_args(%ptr_iter = %tdesc) -> (!tt.tensordesc<tensor<16x32xf16>>) {
      scf.yield %ptr_iter : !tt.tensordesc<tensor<16x32xf16>>
    }
    %cast_c10 = arith.index_cast %c10 : index to i32
    %load2 = tt.descriptor_load %tdesc_out[%c8_i32, %cast_c10] : !tt.tensordesc<tensor<16x32xf16>> -> tensor<16x32xf16>
    tt.return
  }
  // CHECK: tt.func public @load_uses_loop_result({{.*}}) {
  // CHECK-NOT: tt.make_tensor_ptr
  // CHECK-NOT: tt.load
  // CHECK: tt.make_tensor_descriptor
  // CHECK: tt.descriptor_load
  // CHECK: tt.return
  // CHECK: }
}