// RUN: iree-opt --iree-gpu-test-target=gfx942 --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target))" %s --split-input-file | FileCheck %s

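// Fused contraction: a single f16 LHS is contracted against three f16 RHS
// operands in one linalg.generic, accumulated in f32, and each result is
// truncated back to f16.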
func.func @fused_contraction_1(%arg0: tensor<2x4096x640xf16>,
    %arg1 : tensor<10x64x640xf16>, %arg2 : tensor<10x64x640xf16>,
    %arg3 : tensor<10x64x640xf16>)
    -> (tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>) {
  %11 = tensor.empty() : tensor<2x10x4096x64xf16>
  %12 = tensor.empty() : tensor<2x10x4096x64xf32>
  %cst = arith.constant 0.0 : f32
  %13 = linalg.fill ins(%cst : f32)
      outs(%12 : tensor<2x10x4096x64xf32>) -> tensor<2x10x4096x64xf32>
  %14:3 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
      ins(%arg0, %arg1, %arg2, %arg3
          : tensor<2x4096x640xf16>, tensor<10x64x640xf16>, tensor<10x64x640xf16>,
            tensor<10x64x640xf16>)
      outs(%13, %13, %13
          : tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>) {
    ^bb0(%in: f16, %in_0: f16, %in_1: f16, %in_2: f16, %out: f32, %out_3: f32, %out_4: f32):
      %18 = arith.extf %in : f16 to f32
      %19 = arith.extf %in_0 : f16 to f32
      %20 = arith.mulf %18, %19 : f32
      %21 = arith.addf %out, %20 : f32
      %22 = arith.extf %in_1 : f16 to f32
      %23 = arith.mulf %18, %22 : f32
      %24 = arith.addf %out_3, %23 : f32
      %25 = arith.extf %in_2 : f16 to f32
      %26 = arith.mulf %18, %25 : f32
      %27 = arith.addf %out_4, %26 : f32
      linalg.yield %21, %24, %27 : f32, f32, f32
  } -> (tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>)
  %15 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%14#0 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
    ^bb0(%in: f32, %out: f16):
      %18 = arith.truncf %in : f32 to f16
      linalg.yield %18 : f16
  } -> tensor<2x10x4096x64xf16>
  %16 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%14#1 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
    ^bb0(%in: f32, %out: f16):
      %18 = arith.truncf %in : f32 to f16
      linalg.yield %18 : f16
  } -> tensor<2x10x4096x64xf16>
  %17 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%14#2 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
    ^bb0(%in: f32, %out: f16):
      %18 = arith.truncf %in : f32 to f16
      linalg.yield %18 : f16
  } -> tensor<2x10x4096x64xf16>
  return %15, %16, %17
      : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>
}
// CHECK-LABEL: func @fused_contraction_1
// CHECK-COUNT-24: amdgpu.mfma

// -----

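// Same fusion pattern with f32 operands in a plain matmul layout (no batch or
// head dimensions); the accumulated results are returned without truncation.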
func.func @fused_contraction_2(%arg0: tensor<4096x640xf32>,
    %arg1 : tensor<640x640xf32>, %arg2 : tensor<640x640xf32>,
    %arg3 : tensor<640x640xf32>)
    -> (tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>) {
  %11 = tensor.empty() : tensor<4096x640xf32>
  %12 = tensor.empty() : tensor<4096x640xf32>
  %cst = arith.constant 0.0 : f32
  %13 = linalg.fill ins(%cst : f32)
      outs(%12 : tensor<4096x640xf32>) -> tensor<4096x640xf32>
  %14:3 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%arg0, %arg1, %arg2, %arg3
          : tensor<4096x640xf32>, tensor<640x640xf32>, tensor<640x640xf32>,
            tensor<640x640xf32>)
      outs(%13, %13, %13
          : tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>) {
    ^bb0(%in: f32, %in_0: f32, %in_1: f32, %in_2: f32, %out: f32, %out_3: f32, %out_4: f32):
      %20 = arith.mulf %in, %in_0 : f32
      %21 = arith.addf %out, %20 : f32
      %23 = arith.mulf %in, %in_1 : f32
      %24 = arith.addf %out_3, %23 : f32
      %26 = arith.mulf %in, %in_2 : f32
      %27 = arith.addf %out_4, %26 : f32
      linalg.yield %21, %24, %27 : f32, f32, f32
  } -> (tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>)
  return %14#0, %14#1, %14#2
      : tensor<4096x640xf32>, tensor<4096x640xf32>, tensor<4096x640xf32>
}
// CHECK-LABEL: func @fused_contraction_2
// CHECK-COUNT-24: amdgpu.mfma

// -----

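// i8 variant: two fused batch contractions share the LHS, accumulate in i32,
// and convert the results to f16 via sitofp followed by truncf.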
func.func @fused_contraction_3(%arg0 : tensor<2x4096x640xi8>,
    %arg1 : tensor<2x640x640xi8>, %arg2 : tensor<2x640x640xi8>)
    -> (tensor<2x4096x640xf16>, tensor<2x4096x640xf16>) {
  %c0_i32 = arith.constant 0 : i32
  %18 = tensor.empty() : tensor<2x4096x640xf16>
  %19 = tensor.empty() : tensor<2x4096x640xi32>
  %20 = linalg.fill ins(%c0_i32 : i32)
      outs(%19 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
  %21:2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%arg0, %arg1, %arg2 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>, tensor<2x640x640xi8>)
      outs(%20, %20 : tensor<2x4096x640xi32>, tensor<2x4096x640xi32>) {
    ^bb0(%in: i8, %in_0: i8, %in_1: i8, %out: i32, %out_2: i32):
      %24 = arith.extsi %in : i8 to i32
      %25 = arith.extsi %in_0 : i8 to i32
      %26 = arith.muli %24, %25 : i32
      %27 = arith.addi %out, %26 : i32
      %28 = arith.extsi %in_1 : i8 to i32
      %29 = arith.muli %24, %28 : i32
      %30 = arith.addi %out_2, %29 : i32
      linalg.yield %27, %30 : i32, i32
  } -> (tensor<2x4096x640xi32>, tensor<2x4096x640xi32>)
  %22 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%21#0 : tensor<2x4096x640xi32>) outs(%18 : tensor<2x4096x640xf16>) {
    ^bb0(%in: i32, %out: f16):
      %27 = arith.sitofp %in : i32 to f32
      %29 = arith.truncf %27 : f32 to f16
      linalg.yield %29 : f16
  } -> tensor<2x4096x640xf16>
  %23 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%21#1 : tensor<2x4096x640xi32>) outs(%18 : tensor<2x4096x640xf16>) {
    ^bb0(%in: i32, %out: f16):
      %27 = arith.sitofp %in : i32 to f32
      %29 = arith.truncf %27 : f32 to f16
      linalg.yield %29 : f16
  } -> tensor<2x4096x640xf16>
  return %22, %23 : tensor<2x4096x640xf16>, tensor<2x4096x640xf16>
}
// CHECK-LABEL: func @fused_contraction_3
// CHECK-COUNT-24: amdgpu.mfma

// -----

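// Same as @fused_contraction_1, except the third fused contraction writes its
// result in a transposed (d0, d1, d3, d2) layout.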
func.func @fused_contraction_4(%arg0: tensor<2x4096x640xf16>,
    %arg1 : tensor<10x64x640xf16>, %arg2 : tensor<10x64x640xf16>,
    %arg3 : tensor<10x64x640xf16>)
    -> (tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>) {
  %9 = tensor.empty() : tensor<2x10x64x4096xf16>
  %10 = tensor.empty() : tensor<2x10x64x4096xf32>
  %11 = tensor.empty() : tensor<2x10x4096x64xf16>
  %12 = tensor.empty() : tensor<2x10x4096x64xf32>
  %cst = arith.constant 0.0 : f32
  %fill0 = linalg.fill ins(%cst : f32)
      outs(%12 : tensor<2x10x4096x64xf32>) -> tensor<2x10x4096x64xf32>
  %fill1 = linalg.fill ins(%cst : f32)
      outs(%10 : tensor<2x10x64x4096xf32>) -> tensor<2x10x64x4096xf32>
  %14:3 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d2)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
      ins(%arg0, %arg1, %arg2, %arg3
          : tensor<2x4096x640xf16>, tensor<10x64x640xf16>, tensor<10x64x640xf16>,
            tensor<10x64x640xf16>)
      outs(%fill0, %fill0, %fill1
          : tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x64x4096xf32>) {
    ^bb0(%in: f16, %in_0: f16, %in_1: f16, %in_2: f16, %out: f32, %out_3: f32, %out_4: f32):
      %18 = arith.extf %in : f16 to f32
      %19 = arith.extf %in_0 : f16 to f32
      %20 = arith.mulf %18, %19 : f32
      %21 = arith.addf %out, %20 : f32
      %22 = arith.extf %in_1 : f16 to f32
      %23 = arith.mulf %18, %22 : f32
      %24 = arith.addf %out_3, %23 : f32
      %25 = arith.extf %in_2 : f16 to f32
      %26 = arith.mulf %18, %25 : f32
      %27 = arith.addf %out_4, %26 : f32
      linalg.yield %21, %24, %27 : f32, f32, f32
  } -> (tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x64x4096xf32>)
  %15 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%14#0 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
    ^bb0(%in: f32, %out: f16):
      %18 = arith.truncf %in : f32 to f16
      linalg.yield %18 : f16
  } -> tensor<2x10x4096x64xf16>
  %16 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%14#1 : tensor<2x10x4096x64xf32>) outs(%11 : tensor<2x10x4096x64xf16>) {
    ^bb0(%in: f32, %out: f16):
      %18 = arith.truncf %in : f32 to f16
      linalg.yield %18 : f16
  } -> tensor<2x10x4096x64xf16>
  %17 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%14#2 : tensor<2x10x64x4096xf32>) outs(%9 : tensor<2x10x64x4096xf16>) {
    ^bb0(%in: f32, %out: f16):
      %18 = arith.truncf %in : f32 to f16
      linalg.yield %18 : f16
  } -> tensor<2x10x64x4096xf16>
  return %15, %16, %17
      : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>
}
// CHECK-LABEL: func @fused_contraction_4
// CHECK-COUNT-24: amdgpu.mfma