Skip to content

Commit 1762be9

Browse files
authored
Add transpose, reduce and broadcast definitions to xetile (#747)
1 parent e58993b commit 1762be9

File tree

7 files changed

+133
-4
lines changed

7 files changed

+133
-4
lines changed

include/imex/Dialect/XeTile/IR/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ add_mlir_dialect(XeTileOps xetile)
22
add_mlir_doc(XeTileOps XeTileDialect Dialects/ -gen-dialect-doc -dialect=xetile)
33

44
set(LLVM_TARGET_DEFINITIONS XeTileOps.td)
5-
mlir_tablegen(XeTileOpsAttrs.h.inc -gen-attrdef-decls)
6-
mlir_tablegen(XeTileOpsAttrs.cpp.inc -gen-attrdef-defs)
5+
mlir_tablegen(XeTileOpsAttrs.h.inc -gen-attrdef-decls --attrdefs-dialect=xetile)
6+
mlir_tablegen(XeTileOpsAttrs.cpp.inc -gen-attrdef-defs --attrdefs-dialect=xetile)
7+
8+
set(LLVM_TARGET_DEFINITIONS XeTileAttrs.td)
79
mlir_tablegen(XeTileOpsEnums.h.inc -gen-enum-decls)
810
mlir_tablegen(XeTileOpsEnums.cpp.inc -gen-enum-defs)
911
add_public_tablegen_target(XeTileOpsAttrsIncGen)

include/imex/Dialect/XeTile/IR/XeTileDialect.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ def XeTile_Dialect : Dialect {
4949
let cppNamespace = "::imex::xetile";
5050

5151
let dependentDialects = [
52-
"::mlir::memref::MemRefDialect"];
52+
"::mlir::memref::MemRefDialect",
53+
"::mlir::vector::VectorDialect"];
5354

5455
// TODO: temporary disable it.
5556
let useDefaultTypePrinterParser = true;

include/imex/Dialect/XeTile/IR/XeTileOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef _XETILE_OPS_H_INCLUDED_
1616
#define _XETILE_OPS_H_INCLUDED_
1717

18+
#include <mlir/Dialect/Vector/IR/VectorOps.h>
1819
#include <mlir/IR/BuiltinTypeInterfaces.h>
1920
#include <mlir/IR/BuiltinTypes.h>
2021
#include <mlir/IR/Dialect.h>

include/imex/Dialect/XeTile/IR/XeTileOps.td

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ include "imex/Dialect/XeTile/IR/XeTileDialect.td"
1818
include "imex/Dialect/XeTile/IR/XeTileTypes.td"
1919
include "imex/Dialect/XeTile/IR/XeTileAttrs.td"
2020

21+
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
22+
2123
// Base class for dialect operations. This operation inherits from the base
2224
// `Op` class in OpBase.td, and provides:
2325
// * The parent dialect of the operation.
@@ -519,5 +521,52 @@ def XeTile_AtomicRMWOp : XeTile_Op<"atomic_rmw", []> {
519521
}];
520522
}
521523

524+
def XeTile_TransposeOp: XeTile_Op<"transpose", []> {
525+
let summary = "transpose a 2D vector.";
526+
let description = [{
527+
It has the same semantic with `vector.transpose`, but limits the vector to be 2D.
528+
}];
529+
530+
let arguments = (ins XeTile_2DVector: $source,
531+
DenseI64ArrayAttr:$permutation);
532+
let results = (outs XeTile_2DVector: $result);
533+
let assemblyFormat = [{
534+
$source $permutation attr-dict `:` type($source) `->` type($result)
535+
}];
536+
let hasVerifier = 1;
537+
}
538+
539+
def XeTile_ReduceOp: XeTile_Op<"reduce", []> {
540+
let summary = "performs a reduction operation over a 2D vector.";
541+
let description = [{
542+
It has the same semantics as the `vector.multi_reduction`,
543+
but restricts the vector dimension to 2D, and also the result
544+
is 2D too, with the reduced axis being 1.
545+
}];
546+
547+
let arguments = (ins Vector_CombiningKindAttr: $kind,
548+
XeTile_2DVector: $source,
549+
DenseI64ArrayAttr: $reduction_dim);
550+
let results = (outs XeTile_2DVector: $result);
551+
let assemblyFormat = [{
552+
$kind `,` $source $reduction_dim attr-dict `:` type($source) `->` type($result)
553+
}];
554+
555+
let hasVerifier = 1;
556+
}
557+
558+
def XeTile_BroadCastOp: XeTile_Op<"broadcast", []> {
559+
let summary = "broadcast a vector from 1D to 2D.";
560+
561+
let arguments = (ins XeTile_2DVector: $source,
562+
DenseI64ArrayAttr: $broadcast_dim);
563+
let results = (outs XeTile_2DVector: $result);
564+
let assemblyFormat = [{
565+
$source $broadcast_dim attr-dict `:` type($source) `->` type($result)
566+
}];
567+
let hasVerifier = 1;
568+
}
569+
570+
522571

523572
#endif // _XETILE_OPS_TD_INCLUDED_

lib/Dialect/XeTile/IR/XeTileOps.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ mlir::LogicalResult InitTileOp::verify() {
170170
}
171171

172172
if (isSourceMemRef() && sourceMemRefHasStaticShape()) {
173-
auto memrefType = getSourceType().dyn_cast<mlir::MemRefType>();
173+
auto memrefType = mlir::dyn_cast<mlir::MemRefType>(getSourceType());
174174

175175
// Checks for memrefs with format: memref<[shape], strided<[strides],
176176
// offsets:[offset]>>
@@ -905,6 +905,42 @@ mlir::OpFoldResult TileUnpackOp::fold(FoldAdaptor /*adaptor*/) {
905905
return nullptr;
906906
}
907907

908+
mlir::LogicalResult TransposeOp::verify() {
909+
auto srcShape = getSource().getType().getShape();
910+
auto resShape = getResult().getType().getShape();
911+
auto permutation = getPermutation();
912+
913+
auto transposeShape = srcShape.vec();
914+
for (auto [i, j] : llvm::enumerate(permutation)) {
915+
if (j >= (int64_t)srcShape.size())
916+
return emitOpError("permutation index out of bounds");
917+
transposeShape[i] = srcShape[j];
918+
}
919+
920+
if (transposeShape != resShape.vec())
921+
return emitOpError("Incorrect transpose permutation");
922+
923+
return mlir::success();
924+
}
925+
926+
mlir::LogicalResult ReduceOp::verify() {
927+
auto dims = getReductionDim();
928+
auto resShape = getResult().getType().getShape();
929+
for (auto i : dims)
930+
if (resShape[i] != 1)
931+
return emitOpError("reduction dimension of result must have size 1");
932+
return mlir::success();
933+
}
934+
935+
mlir::LogicalResult BroadCastOp::verify() {
936+
auto dims = getBroadcastDim();
937+
auto srcShape = getSource().getType().getShape();
938+
for (auto i : dims)
939+
if (srcShape[i] != 1)
940+
return emitOpError("broadcast dimension of source must have size 1");
941+
return mlir::success();
942+
}
943+
908944
} // namespace xetile
909945
} // namespace imex
910946

test/Dialect/XeTile/IR/invalid.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,25 @@ func.func @tile_unpack_invalid_output_shape(%in : vector<4x4x16x16xf16>) {
228228
#wg_map_4 = #xetile.tile_attr<order = [0, 1, 2]>
229229
// expected-error@+1 {{expect integer array of size 2 for wg_data}}
230230
#wg_map_5 = #xetile.tile_attr<wg_data = [32, 64, 128]>
231+
232+
233+
// -----
234+
func.func @test_transpose(%source: vector<8x16xf16>) {
235+
// expected-error@+1 {{Incorrect transpose permutation}}
236+
%1 = xetile.transpose %source [0, 1] : vector<8x16xf16> -> vector<16x8xf16>
237+
return
238+
}
239+
240+
// -----
241+
func.func @test_reduce(%source: vector<8x16xf16>) {
242+
// expected-error@+1 {{reduction dimension of result must have size 1}}
243+
%1 = xetile.reduce <add>, %source [0] : vector<8x16xf16> -> vector<2x16xf16>
244+
return
245+
}
246+
247+
// -----
248+
func.func @test_broadcast(%source: vector<2x16xf16>) {
249+
// expected-error@+1 {{broadcast dimension of source must have size 1}}
250+
%1 = xetile.broadcast %source [0] : vector<2x16xf16> -> vector<8x16xf16>
251+
return
252+
}

test/Dialect/XeTile/IR/ops.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,3 +309,21 @@ func.func @test_atomic_rmw(%tile : !xetile.tile<8x16xf16>, %value : vector<8x16x
309309
%1 = xetile.atomic_rmw "addf" %value, %tile : vector<8x16xf16>, !xetile.tile<8x16xf16> -> vector<8x16xf16>
310310
return
311311
}
312+
313+
func.func @test_transpose(%source: vector<8x16xf16>) {
314+
// CHECK: xetile.transpose {{.*}} [1, 0] : vector<8x16xf16> -> vector<16x8xf16>
315+
%1 = xetile.transpose %source [1, 0] : vector<8x16xf16> -> vector<16x8xf16>
316+
return
317+
}
318+
319+
func.func @test_reduce(%source: vector<8x16xf16>) {
320+
// CHECK: xetile.reduce {{.*}} [0] : vector<8x16xf16> -> vector<1x16xf16>
321+
%1 = xetile.reduce <add>, %source [0] : vector<8x16xf16> -> vector<1x16xf16>
322+
return
323+
}
324+
325+
func.func @test_broadcast(%source: vector<1x16xf16>) {
326+
// CHECK: xetile.broadcast {{.*}} [0] : vector<1x16xf16> -> vector<8x16xf16>
327+
%1 = xetile.broadcast %source [0] : vector<1x16xf16> -> vector<8x16xf16>
328+
return
329+
}

0 commit comments

Comments
 (0)