Skip to content

Commit fef3160

Browse files
committed
Add check for order attribute in Xetile.init_tile verifier (#730)
1 parent 48d6723 commit fef3160

File tree

2 files changed

+157
-1
lines changed

2 files changed

+157
-1
lines changed

lib/Dialect/XeTile/IR/XeTileOps.cpp

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
///
1313
//===----------------------------------------------------------------------===//
1414

15+
#include "mlir/IR/AffineMap.h"
1516
#include "mlir/IR/Attributes.h"
1617
#include "mlir/IR/BuiltinAttributes.h"
1718
#include "mlir/IR/BuiltinTypes.h"
@@ -95,6 +96,30 @@ parseOptionalAttrDict(mlir::OpAsmParser &parser, mlir::OperationState &result,
9596
return mlir::success();
9697
}
9798

99+
static bool isColumnMajor(mlir::AffineMap layoutMap) {
100+
if (layoutMap.getNumDims() != 2 || layoutMap.getNumResults() != 2) {
101+
return false;
102+
}
103+
104+
auto results = layoutMap.getResults();
105+
if (mlir::isa<mlir::AffineDimExpr>(results[0]) &&
106+
mlir::isa<mlir::AffineDimExpr>(results[1])) {
107+
auto dimExpr0 = mlir::cast<mlir::AffineDimExpr>(results[0]);
108+
auto dimExpr1 = mlir::cast<mlir::AffineDimExpr>(results[1]);
109+
return dimExpr0.getPosition() == 1 && dimExpr1.getPosition() == 0;
110+
}
111+
return false;
112+
}
113+
114+
static bool isConstantIndex(mlir::Value value) {
115+
return value.getDefiningOp<mlir::arith::ConstantOp>() != nullptr;
116+
}
117+
118+
static int64_t getConstantValue(mlir::Value value) {
119+
auto constOp = value.getDefiningOp<mlir::arith::ConstantOp>();
120+
return constOp.getValue().cast<mlir::IntegerAttr>().getInt();
121+
}
122+
98123
mlir::LogicalResult InitTileOp::verify() {
99124

100125
// number of offsets must be 2 because init_tile creates 2D tiles
@@ -134,6 +159,88 @@ mlir::LogicalResult InitTileOp::verify() {
134159
return emitOpError("address is used as source but dynamic strides argument "
135160
"is missing or it is not 2D");
136161

162+
// Check for order attribute
163+
bool row_major = true;
164+
bool col_major = false;
165+
auto tileTy = getType();
166+
auto order = tileTy.getOrder();
167+
if (order[0] == 0 && order[1] == 1) {
168+
col_major = true;
169+
row_major = false;
170+
}
171+
172+
if (isSourceMemRef() && sourceMemRefHasStaticShape()) {
173+
auto memrefType = getSourceType().dyn_cast<mlir::MemRefType>();
174+
175+
// Checks for memrefs with format: memref<[shape], strided<[strides],
176+
// offsets:[offset]>>
177+
llvm::SmallVector<int64_t, 4> strides;
178+
auto shape = getSourceMemrefStaticShape();
179+
int64_t offset;
180+
if (mlir::succeeded(
181+
mlir::getStridesAndOffset(memrefType, strides, offset))) {
182+
if (row_major && !((strides[0] == shape[1]) && (strides[1] == 1)))
183+
return emitOpError(
184+
"memref operand is expected to have a row-major layout");
185+
186+
if (col_major && !((strides[0] == 1) && (strides[1] == shape[0])))
187+
return emitOpError(
188+
"memref operand is expected to have a column-major layout");
189+
return mlir::success();
190+
}
191+
192+
// Checks for memrefs with affine maps : memref<[shape], affine_map<(d0, d1)
193+
// -> (d1, d0)>>
194+
if (row_major && !(memrefType.getLayout().isIdentity())) {
195+
// No affine map means it's using the default row-major layout
196+
return emitOpError(
197+
"memref operand is expected to have a row-major layout");
198+
}
199+
200+
if (col_major) {
201+
auto layoutAttr = memrefType.getLayout().dyn_cast<mlir::AffineMapAttr>();
202+
if (!layoutAttr) {
203+
return emitOpError("expected a valid affine map in the layout");
204+
}
205+
mlir::AffineMap layoutMap = layoutAttr.getValue();
206+
207+
if (!isColumnMajor(layoutMap)) {
208+
return emitOpError(
209+
"memref operand is expected to have a column-major layout");
210+
}
211+
}
212+
} else if (isSourceInteger()) {
213+
auto dynamicShape = getDynamicShape();
214+
auto dynamicStrides = getDynamicStrides();
215+
216+
if (dynamicShape.size() == 0 || dynamicStrides.size() == 0) {
217+
return emitOpError("dynamic shape and strides must not be empty");
218+
}
219+
220+
// Check if all shape and stride values are constant.
221+
if (!llvm::all_of(dynamicShape, isConstantIndex) ||
222+
!llvm::all_of(dynamicStrides, isConstantIndex)) {
223+
llvm::dbgs() << "Assuming user has verified the layout\n";
224+
return mlir::success();
225+
}
226+
227+
auto shapeDim1 = getConstantValue(dynamicShape[1]);
228+
auto strideDim0 = getConstantValue(dynamicStrides[0]);
229+
auto strideDim1 = getConstantValue(dynamicStrides[1]);
230+
231+
// checks for layouts where source is not memref and just an address
232+
if (row_major && (strideDim0 == 1 && strideDim1 == shapeDim1)) {
233+
return emitOpError(
234+
"memref operand is expected to have a row-major layout");
235+
}
236+
237+
if (col_major && !(strideDim0 == 1 && strideDim1 == shapeDim1)) {
238+
return emitOpError(
239+
"memref operand is expected to have a column-major layout");
240+
}
241+
} else if (isSourceMemRef() && !sourceMemRefHasStaticShape())
242+
llvm::dbgs() << "Assuming user has verified the layout\n";
243+
137244
return mlir::success();
138245
}
139246

test/Dialect/XeTile/IR/invalid.mlir

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
// RUN: imex-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics
22

3-
43
// -----
54
func.func @init_tile_with_invalid_offsets(%source : memref<64x64xf32>, %offset : index) {
65
// the offsets of the init_tile must be 2D
@@ -9,6 +8,56 @@ func.func @init_tile_with_invalid_offsets(%source : memref<64x64xf32>, %offset :
98
: memref<64x64xf32> -> !xetile.tile<8x8xf32>
109
}
1110

11+
// -----
12+
func.func @test_init_tile_invalid_order(%src: memref<1024x1024xf16>) {
13+
// Memref is row major but tile is column major
14+
// expected-error@+1 {{memref operand is expected to have a column-major layout}}
15+
%1 = xetile.init_tile %src[8, 16] : memref<1024x1024xf16> -> !xetile.tile<32x64xf16, #xetile.tile_attr<order = [0, 1]>>
16+
return
17+
}
18+
19+
// -----
20+
func.func @test_init_tile_with_invalid_order(%a: memref<1024x1024xf16, affine_map<(d0, d1) -> (d1, d0)>>) {
21+
// Memref is column major but tile is row major
22+
// expected-error@+1 {{memref operand is expected to have a row-major layout}}
23+
%1 = xetile.init_tile %a[8, 16] : memref<1024x1024xf16, affine_map<(d0, d1) -> (d1, d0)>> -> !xetile.tile<32x64xf16>
24+
return
25+
}
26+
27+
// -----
28+
func.func @test_init_tile_with_invalid_strided_layout(%a: memref<512x1024xf16, strided<[1, 256], offset: ?>>) {
29+
// Memref is column major but tile is row major
30+
// expected-error@+1 {{memref operand is expected to have a row-major layout}}
31+
%1 = xetile.init_tile %a[8, 16] : memref<512x1024xf16, strided<[1, 256], offset: ?>> -> !xetile.tile<32x64xf16>
32+
return
33+
}
34+
35+
// -----
36+
func.func @test_init_tile_invalid_order_using_address(%src : i64) {
37+
// Expected row major access
38+
%c1 = arith.constant 1 : index
39+
%c64 = arith.constant 64 : index
40+
%c256 = arith.constant 512 : index
41+
%c1024 = arith.constant 1024 : index
42+
// expected-error@+1 {{memref operand is expected to have a row-major layout}}
43+
%1 = xetile.init_tile %src[%c256, %c64], [%c1024, %c1024], [%c1, %c1024] : i64 -> !xetile.tile<32x64xf16, #xetile.tile_attr<order = [1, 0]>>
44+
return
45+
}
46+
47+
// -----
48+
func.func @test_init_tile_using_address(%src : i64) {
49+
// Expected column major access
50+
%c1 = arith.constant 1 : index
51+
%c2 = arith.constant 1 : index
52+
%c64 = arith.constant 64 : index
53+
%c256 = arith.constant 256 : index
54+
%c512 = arith.constant 512 : index
55+
%c1024 = arith.constant 1024 : index
56+
// expected-error@+1 {{memref operand is expected to have a column-major layout}}
57+
%1 = xetile.init_tile %src[%c256, %c64], [%c512, %c1024], [%c1024, %c1] : i64 -> !xetile.tile<32x64xf16, #xetile.tile_attr<order = [0, 1]>>
58+
return
59+
}
60+
1261
// -----
1362
func.func @init_tile_static_memref_with_invalid_dynamic_shape(%source : memref<1024x1024xf32>,
1463
%dim0_size : index, %dim1_size : index) {

0 commit comments

Comments
 (0)