|
12 | 12 | ///
|
13 | 13 | //===----------------------------------------------------------------------===//
|
14 | 14 |
|
| 15 | +#include "mlir/IR/AffineMap.h" |
15 | 16 | #include "mlir/IR/Attributes.h"
|
16 | 17 | #include "mlir/IR/BuiltinAttributes.h"
|
17 | 18 | #include "mlir/IR/BuiltinTypes.h"
|
@@ -95,6 +96,30 @@ parseOptionalAttrDict(mlir::OpAsmParser &parser, mlir::OperationState &result,
|
95 | 96 | return mlir::success();
|
96 | 97 | }
|
97 | 98 |
|
| 99 | +static bool isColumnMajor(mlir::AffineMap layoutMap) { |
| 100 | + if (layoutMap.getNumDims() != 2 || layoutMap.getNumResults() != 2) { |
| 101 | + return false; |
| 102 | + } |
| 103 | + |
| 104 | + auto results = layoutMap.getResults(); |
| 105 | + if (mlir::isa<mlir::AffineDimExpr>(results[0]) && |
| 106 | + mlir::isa<mlir::AffineDimExpr>(results[1])) { |
| 107 | + auto dimExpr0 = mlir::cast<mlir::AffineDimExpr>(results[0]); |
| 108 | + auto dimExpr1 = mlir::cast<mlir::AffineDimExpr>(results[1]); |
| 109 | + return dimExpr0.getPosition() == 1 && dimExpr1.getPosition() == 0; |
| 110 | + } |
| 111 | + return false; |
| 112 | +} |
| 113 | + |
| 114 | +static bool isConstantIndex(mlir::Value value) { |
| 115 | + return value.getDefiningOp<mlir::arith::ConstantOp>() != nullptr; |
| 116 | +} |
| 117 | + |
| 118 | +static int64_t getConstantValue(mlir::Value value) { |
| 119 | + auto constOp = value.getDefiningOp<mlir::arith::ConstantOp>(); |
| 120 | + return constOp.getValue().cast<mlir::IntegerAttr>().getInt(); |
| 121 | +} |
| 122 | + |
98 | 123 | mlir::LogicalResult InitTileOp::verify() {
|
99 | 124 |
|
100 | 125 | // number of offsets must be 2 because init_tile creates 2D tiles
|
@@ -134,6 +159,88 @@ mlir::LogicalResult InitTileOp::verify() {
|
134 | 159 | return emitOpError("address is used as source but dynamic strides argument "
|
135 | 160 | "is missing or it is not 2D");
|
136 | 161 |
|
| 162 | + // Check for order attribute |
| 163 | + bool row_major = true; |
| 164 | + bool col_major = false; |
| 165 | + auto tileTy = getType(); |
| 166 | + auto order = tileTy.getOrder(); |
| 167 | + if (order[0] == 0 && order[1] == 1) { |
| 168 | + col_major = true; |
| 169 | + row_major = false; |
| 170 | + } |
| 171 | + |
| 172 | + if (isSourceMemRef() && sourceMemRefHasStaticShape()) { |
| 173 | + auto memrefType = getSourceType().dyn_cast<mlir::MemRefType>(); |
| 174 | + |
| 175 | + // Checks for memrefs with format: memref<[shape], strided<[strides], |
| 176 | + // offsets:[offset]>> |
| 177 | + llvm::SmallVector<int64_t, 4> strides; |
| 178 | + auto shape = getSourceMemrefStaticShape(); |
| 179 | + int64_t offset; |
| 180 | + if (mlir::succeeded( |
| 181 | + mlir::getStridesAndOffset(memrefType, strides, offset))) { |
| 182 | + if (row_major && !((strides[0] == shape[1]) && (strides[1] == 1))) |
| 183 | + return emitOpError( |
| 184 | + "memref operand is expected to have a row-major layout"); |
| 185 | + |
| 186 | + if (col_major && !((strides[0] == 1) && (strides[1] == shape[0]))) |
| 187 | + return emitOpError( |
| 188 | + "memref operand is expected to have a column-major layout"); |
| 189 | + return mlir::success(); |
| 190 | + } |
| 191 | + |
| 192 | + // Checks for memrefs with affine maps : memref<[shape], affine_map<(d0, d1) |
| 193 | + // -> (d1, d0)>> |
| 194 | + if (row_major && !(memrefType.getLayout().isIdentity())) { |
| 195 | + // No affine map means it's using the default row-major layout |
| 196 | + return emitOpError( |
| 197 | + "memref operand is expected to have a row-major layout"); |
| 198 | + } |
| 199 | + |
| 200 | + if (col_major) { |
| 201 | + auto layoutAttr = memrefType.getLayout().dyn_cast<mlir::AffineMapAttr>(); |
| 202 | + if (!layoutAttr) { |
| 203 | + return emitOpError("expected a valid affine map in the layout"); |
| 204 | + } |
| 205 | + mlir::AffineMap layoutMap = layoutAttr.getValue(); |
| 206 | + |
| 207 | + if (!isColumnMajor(layoutMap)) { |
| 208 | + return emitOpError( |
| 209 | + "memref operand is expected to have a column-major layout"); |
| 210 | + } |
| 211 | + } |
| 212 | + } else if (isSourceInteger()) { |
| 213 | + auto dynamicShape = getDynamicShape(); |
| 214 | + auto dynamicStrides = getDynamicStrides(); |
| 215 | + |
| 216 | + if (dynamicShape.size() == 0 || dynamicStrides.size() == 0) { |
| 217 | + return emitOpError("dynamic shape and strides must not be empty"); |
| 218 | + } |
| 219 | + |
| 220 | + // Check if all shape and stride values are constant. |
| 221 | + if (!llvm::all_of(dynamicShape, isConstantIndex) || |
| 222 | + !llvm::all_of(dynamicStrides, isConstantIndex)) { |
| 223 | + llvm::dbgs() << "Assuming user has verified the layout\n"; |
| 224 | + return mlir::success(); |
| 225 | + } |
| 226 | + |
| 227 | + auto shapeDim1 = getConstantValue(dynamicShape[1]); |
| 228 | + auto strideDim0 = getConstantValue(dynamicStrides[0]); |
| 229 | + auto strideDim1 = getConstantValue(dynamicStrides[1]); |
| 230 | + |
| 231 | + // checks for layouts where source is not memref and just an address |
| 232 | + if (row_major && (strideDim0 == 1 && strideDim1 == shapeDim1)) { |
| 233 | + return emitOpError( |
| 234 | + "memref operand is expected to have a row-major layout"); |
| 235 | + } |
| 236 | + |
| 237 | + if (col_major && !(strideDim0 == 1 && strideDim1 == shapeDim1)) { |
| 238 | + return emitOpError( |
| 239 | + "memref operand is expected to have a column-major layout"); |
| 240 | + } |
| 241 | + } else if (isSourceMemRef() && !sourceMemRefHasStaticShape()) |
| 242 | + llvm::dbgs() << "Assuming user has verified the layout\n"; |
| 243 | + |
137 | 244 | return mlir::success();
|
138 | 245 | }
|
139 | 246 |
|
|
0 commit comments