|
1 | | -// RUN: iree-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-promote-matmul-operands))" | FileCheck %s |
| 1 | +// RUN: iree-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-promote-matmul-operands),canonicalize)" | FileCheck %s |
2 | 2 |
|
3 | 3 | #lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1]}> |
4 | 4 |
|
@@ -214,8 +214,88 @@ func.func @promote_with_cache_swizzle(%a: tensor<2x34x34x128xf32>, %b: tensor<2x |
214 | 214 | // CHECK-LABEL: func.func @promote_with_cache_swizzle |
215 | 215 | // CHECK-SAME: %[[A:[A-Za-z0-9]+]]: tensor<2x34x34x128xf32> |
216 | 216 | // CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<2x8x256xf32> |
217 | | -// CHECK-DAG: %[[SWIZZLE_A:.+]] = iree_gpu.buffer_resource_cast %[[A]] cacheSwizzleStride(%c128) |
218 | | -// CHECK-DAG: %[[SWIZZLE_B:.+]] = iree_gpu.buffer_resource_cast %[[B]] cacheSwizzleStride(%c256) |
| 217 | +// CHECK-DAG: %[[SWIZZLE_A:.+]] = iree_gpu.buffer_resource_cast %[[A]] cacheSwizzleStride(%c512) |
| 218 | +// CHECK-DAG: %[[SWIZZLE_B:.+]] = iree_gpu.buffer_resource_cast %[[B]] cacheSwizzleStride(%c1024) |
| 219 | +// CHECK: %[[PA:.+]] = iree_linalg_ext.im2col |
| 220 | +// CHECK-SAME: lowering_config = #iree_gpu.derived_thread_config |
| 221 | +// CHECK-SAME: ins(%[[SWIZZLE_A]] |
| 222 | +// CHECK: %[[PB:.+]] = linalg.copy |
| 223 | +// CHECK-SAME: lowering_config = #iree_gpu.use_global_load_dma |
| 224 | +// CHECK-SAME: ins(%[[SWIZZLE_B]] |
| 225 | +// CHECK: linalg.batch_matmul {{.*}} ins(%[[PA]], %[[PB]] |
| 226 | + |
| 227 | + |
| 228 | +// ----- |
| 229 | + |
| 230 | +#lowering_config = #iree_gpu.lowering_config<{ |
| 231 | + promote_operands = [0, 1], |
| 232 | + promotion_types = [ |
| 233 | + #iree_gpu.promote_with_cache_swizzle<#iree_gpu.derived_thread_config>, |
| 234 | + #iree_gpu.promote_with_cache_swizzle<#iree_gpu.use_global_load_dma>]}> |
| 235 | + |
| 236 | +func.func @promote_with_cache_swizzle_f4(%a: tensor<2x34x34x128xf4E2M1FN>, %b: tensor<2x8x256xf4E2M1FN>) -> tensor<2x128x256xf32> { |
| 237 | + %cst = arith.constant 0.000000e+00 : f32 |
| 238 | + %empty = tensor.empty() : tensor<2x128x256xf32> |
| 239 | + %im2col_empty = tensor.empty() : tensor<2x128x8xf4E2M1FN> |
| 240 | + |
| 241 | + %im2col = iree_linalg_ext.im2col |
| 242 | + strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] |
| 243 | + m_offset = [0] * [1] k_offset = [0] * [1] |
| 244 | + batch_pos = [0] m_pos = [2, 3] k_pos = [1] |
| 245 | + input_k_perm = [0, 1, 2] |
| 246 | + ins(%a : tensor<2x34x34x128xf4E2M1FN>) |
| 247 | + outs(%im2col_empty : tensor<2x128x8xf4E2M1FN>) -> tensor<2x128x8xf4E2M1FN> |
| 248 | + |
| 249 | + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<2x128x256xf32>) -> tensor<2x128x256xf32> |
| 250 | + %mm = linalg.batch_matmul {lowering_config = #lowering_config} |
| 251 | + ins(%im2col, %b : tensor<2x128x8xf4E2M1FN>, tensor<2x8x256xf4E2M1FN>) outs(%fill : tensor<2x128x256xf32>) -> tensor<2x128x256xf32> |
| 252 | + return %mm : tensor<2x128x256xf32> |
| 253 | +} |
| 254 | + |
| 255 | +// CHECK-LABEL: func.func @promote_with_cache_swizzle_f4 |
| 256 | +// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: tensor<2x34x34x128xf4E2M1FN> |
| 257 | +// CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<2x8x256xf4E2M1FN> |
| 258 | +// CHECK-DAG: %[[SWIZZLE_A:.+]] = iree_gpu.buffer_resource_cast %[[A]] cacheSwizzleStride(%c64) |
| 259 | +// CHECK-DAG: %[[SWIZZLE_B:.+]] = iree_gpu.buffer_resource_cast %[[B]] cacheSwizzleStride(%c128) |
| 260 | +// CHECK: %[[PA:.+]] = iree_linalg_ext.im2col |
| 261 | +// CHECK-SAME: lowering_config = #iree_gpu.derived_thread_config |
| 262 | +// CHECK-SAME: ins(%[[SWIZZLE_A]] |
| 263 | +// CHECK: %[[PB:.+]] = linalg.copy |
| 264 | +// CHECK-SAME: lowering_config = #iree_gpu.use_global_load_dma |
| 265 | +// CHECK-SAME: ins(%[[SWIZZLE_B]] |
| 266 | +// CHECK: linalg.batch_matmul {{.*}} ins(%[[PA]], %[[PB]] |
| 267 | + |
| 268 | +// ----- |
| 269 | +#lowering_config = #iree_gpu.lowering_config<{ |
| 270 | + promote_operands = [0, 1], |
| 271 | + promotion_types = [ |
| 272 | + #iree_gpu.promote_with_cache_swizzle<#iree_gpu.derived_thread_config>, |
| 273 | + #iree_gpu.promote_with_cache_swizzle<#iree_gpu.use_global_load_dma>]}> |
| 274 | + |
| 275 | +func.func @promote_with_cache_swizzle_f4_no_stride(%a: tensor<2x34x34x129xf4E2M1FN>, %b: tensor<2x8x256xf4E2M1FN>) -> tensor<2x129x256xf32> { |
| 276 | + %cst = arith.constant 0.000000e+00 : f32 |
| 277 | + %empty = tensor.empty() : tensor<2x129x256xf32> |
| 278 | + %im2col_empty = tensor.empty() : tensor<2x129x8xf4E2M1FN> |
| 279 | + |
| 280 | + %im2col = iree_linalg_ext.im2col |
| 281 | + strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] |
| 282 | + m_offset = [0] * [1] k_offset = [0] * [1] |
| 283 | + batch_pos = [0] m_pos = [2, 3] k_pos = [1] |
| 284 | + input_k_perm = [0, 1, 2] |
| 285 | + ins(%a : tensor<2x34x34x129xf4E2M1FN>) |
| 286 | + outs(%im2col_empty : tensor<2x129x8xf4E2M1FN>) -> tensor<2x129x8xf4E2M1FN> |
| 287 | + |
| 288 | + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<2x129x256xf32>) -> tensor<2x129x256xf32> |
| 289 | + %mm = linalg.batch_matmul {lowering_config = #lowering_config} |
| 290 | + ins(%im2col, %b : tensor<2x129x8xf4E2M1FN>, tensor<2x8x256xf4E2M1FN>) outs(%fill : tensor<2x129x256xf32>) -> tensor<2x129x256xf32> |
| 291 | + return %mm : tensor<2x129x256xf32> |
| 292 | +} |
| 293 | + |
| 294 | +// CHECK-LABEL: func.func @promote_with_cache_swizzle_f4_no_stride |
| 295 | +// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: tensor<2x34x34x129xf4E2M1FN> |
| 296 | +// CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<2x8x256xf4E2M1FN> |
| 297 | +// CHECK-DAG: %[[SWIZZLE_A:.+]] = iree_gpu.buffer_resource_cast %[[A]] cacheSwizzleStride(%c0) |
| 298 | +// CHECK-DAG: %[[SWIZZLE_B:.+]] = iree_gpu.buffer_resource_cast %[[B]] cacheSwizzleStride(%c128) |
219 | 299 | // CHECK: %[[PA:.+]] = iree_linalg_ext.im2col |
220 | 300 | // CHECK-SAME: lowering_config = #iree_gpu.derived_thread_config |
221 | 301 | // CHECK-SAME: ins(%[[SWIZZLE_A]] |
|