|
1 | 1 | // RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s |
2 | 2 |
|
3 | | -// CHECK-LABEL: @get_desc_op |
4 | | -func.func @get_desc_op(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { |
// Payload for the first-operand case: inside the scf.for loop, %5 is loaded
// from descriptor %3 (created from %arg0) and is passed as operand 0 of
// xegpu.dpas. The remark annotation in the body sits directly above the %3
// create_nd_tdesc, so that is the op the transform script in this split
// chunk is meant to flag.
| 3 | +// CHECK-LABEL: @get_desc_op_a |
| 4 | +func.func @get_desc_op_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { |
| 5 | + %c32 = arith.constant 32 : index |
| 6 | + %c4096 = arith.constant 4096 : index |
| 7 | + %c0 = arith.constant 0 : index |
| 8 | + %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> |
| 9 | + %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> |
| 10 | + // expected-remark @below {{found desc op}} |
| 11 | + %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> |
| 12 | + %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> |
| 13 | + %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) { |
| 14 | + %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> |
| 15 | + %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> |
| 16 | + %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> |
| 17 | + scf.yield %7 : vector<256x256xf16> |
| 18 | + } |
| 19 | + return |
| 20 | +} |
| 21 | + |
// Transform script for @get_desc_op_a: matches the xegpu.dpas op(s) in the
// payload, takes operand 0 of the match as a value handle, passes it to
// transform.xegpu.get_desc_op, and emits the remark "found desc op" at the
// op handle that comes back.
| 22 | +module attributes {transform.with_named_sequence} { |
| 23 | + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { |
| 24 | + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op |
| 25 | + %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value |
| 26 | + %2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op |
| 27 | + transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op |
| 28 | + transform.yield |
| 29 | + } |
| 30 | +} |
| 31 | + |
| 32 | +// ----- |
| 33 | + |
// Payload for the third-operand (accumulator) case: operand 2 of xegpu.dpas
// is %arg4, the scf.for iter_arg initialised with %1, and %1 is loaded from
// descriptor %0 (created from %arg2). In the added rows the remark annotation
// sits directly above that %0 create_nd_tdesc.
// NOTE(review): this presumes get_desc_op follows the value through the loop
// iter_arg back to its init value -- confirm against the op implementation.
// Rows marked '-' below are the body of the previous straight-line test that
// this change replaces.
| 34 | +// CHECK-LABEL: @get_desc_op_c |
| 35 | +func.func @get_desc_op_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { |
| 36 | + %c32 = arith.constant 32 : index |
| 37 | + %c4096 = arith.constant 4096 : index |
5 | 38 | %c0 = arith.constant 0 : index |
6 | | - %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> |
7 | | - %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> |
8 | 39 | // expected-remark @below {{found desc op}} |
9 | | - %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> |
10 | | - %3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> |
11 | | - %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> |
12 | | - %5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> |
13 | | - %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> |
| 40 | + %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> |
| 41 | + %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> |
| 42 | + %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> |
| 43 | + %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> |
| 44 | + %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) { |
| 45 | + %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> |
| 46 | + %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> |
| 47 | + %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> |
| 48 | + scf.yield %7 : vector<256x256xf16> |
| 49 | + } |
14 | 50 | return |
15 | 51 | } |
16 | 52 |
|
// Transform script for @get_desc_op_c: matches the xegpu.dpas op(s) in the
// payload, takes one operand of the match as a value handle, passes it to
// transform.xegpu.get_desc_op, and emits the remark "found desc op" at the
// op handle that comes back. The '-' row is the old query (operand 1); the
// '+' row is the new query (operand 2, the third dpas operand).
17 | 53 | module attributes {transform.with_named_sequence} {
18 | 54 | transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
19 | 55 | %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
20 | | - %1 = transform.get_operand %0[1] : (!transform.any_op) -> !transform.any_value |
| 56 | + %1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value |
21 | 57 | %2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op
22 | 58 | transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op
23 | 59 | transform.yield
24 | 60 | }
25 | 61 | }
26 | 62 |
|
27 | 63 | // ----- |
| 64 | + |
28 | 65 | // CHECK-LABEL: @set_desc_layout |
29 | 66 | func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) { |
30 | 67 | // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 |
|
0 commit comments