 //===----------------------------------------------------------------------===//

 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cstdint>
 #include <numeric>
@@ -127,3 +130,182 @@ std::string xegpu::getLayoutName(OpResult res) { |
   const StringRef prefix = "layout_result_";
   return llvm::formatv("{0}{1}", prefix, res.getResultNumber()).str();
 }
+
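+/// Lifts VectorType values to RankedTensorType so that an xegpu::LayoutAttr
+/// can be attached as the tensor encoding and carried across SCF region
+/// boundaries (scf.for, scf.while), then converts the tensors back to one or
+/// more VectorTypes according to that layout. The work proceeds in three
+/// stages: (1) an SCF structural type conversion from VectorType to
+/// RankedTensorType, (2) layout propagation through the inserted
+/// unrealized_conversion_cast ops, and (3) a conversion from RankedTensorType
+/// (and TensorDescType) back to distributed VectorTypes.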
+void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
+  MLIRContext *context = op->getContext();
+
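+  // Materialization callback shared by both conversion stages: bridge any
+  // type mismatch by inserting a builtin.unrealized_conversion_cast and
+  // returning its single result.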
+  auto materializeCast = [&](OpBuilder &builder, Type type, ValueRange inputs,
+                             Location loc) -> Value {
+    return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
+        .getResult(0);
+  };
+
+  { // Convert VectorType to RankedTensorType for SCF structural ops.
+    TypeConverter converter;
+    converter.addConversion([&](Type type) -> Type { return type; });
+    converter.addConversion([&](VectorType type) -> Type {
+      return RankedTensorType::get(type.getShape(), type.getElementType());
+    });
+    converter.addSourceMaterialization(materializeCast);
+    converter.addTargetMaterialization(materializeCast);
+
+    mlir::ConversionTarget target(*context);
+    target.addLegalOp<UnrealizedConversionCastOp>();
+
+    mlir::RewritePatternSet patterns(context);
+    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                         target);
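+    // Partial conversion is used (here and again in the final stage below)
+    // so that the remaining unrealized_conversion_cast ops, which are marked
+    // legal, may survive; the result is therefore discarded.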
+    (void)mlir::applyPartialConversion(op, target, std::move(patterns));
+  }
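+  // Illustrative effect of this first stage (IR shown for exposition, not
+  // produced verbatim): an scf.for iter_arg of type vector<128x128xf16>
+  // becomes tensor<128x128xf16>, with unrealized_conversion_cast ops
+  // bridging the vector/tensor boundary at the loop init and results.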
+
+  { // Propagate the layout attribute to the RankedTensorType by inspecting
+    // builtin.unrealized_conversion_cast ops that cast from VectorType to
+    // RankedTensorType.
+    op->walk([&](UnrealizedConversionCastOp castOp) {
+      if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)
+        return WalkResult::skip();
+
+      Value input = castOp.getInputs()[0];
+      Value result = castOp.getResults()[0];
+      auto inputTy = dyn_cast<VectorType>(input.getType());
+      auto resultTy = dyn_cast<RankedTensorType>(result.getType());
+
+      // Only proceed for casts from VectorType to RankedTensorType.
+      if (!inputTy || !resultTy)
+        return WalkResult::skip();
+
+      xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input);
+      if (!layout)
+        return WalkResult::skip();
+
+      RankedTensorType newTy = resultTy.cloneWithEncoding(layout);
+      result.setType(newTy);
+
+      // Update the iteration arguments if the user is a loop-like op.
+      for (OpOperand &use : result.getUses()) {
+        if (auto loop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
+          BlockArgument arg = loop.getTiedLoopRegionIterArg(&use);
+          arg.setType(newTy);
+        }
+        // scf.while has two regions; the block arguments of the "after"
+        // region are not exposed by LoopLikeOpInterface, so handle them
+        // explicitly.
+        if (auto whileOp = dyn_cast<scf::WhileOp>(use.getOwner())) {
+          unsigned idx = use.getOperandNumber();
+          BlockArgument arg = whileOp.getAfterArguments()[idx];
+          arg.setType(newTy);
+        }
+      }
+      return WalkResult::advance();
+    });
+
+    // Use scf.yield ops as anchors to update the result types of their
+    // parent ops.
+    op->walk([&](scf::YieldOp yieldOp) {
+      Operation *parentOp = yieldOp->getParentOp();
+      for (OpResult r : parentOp->getOpResults()) {
+        unsigned idx = r.getResultNumber();
+        Type resultTy = r.getType();
+        Type yieldTy = yieldOp.getResults()[idx].getType();
+        if (isa<RankedTensorType>(resultTy) && yieldTy != resultTy)
+          r.setType(yieldTy);
+      }
+    });
+  }
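+  // At this point a propagated type looks like, e.g. (illustrative syntax):
+  //   tensor<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>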
+
+  { // Convert RankedTensorType back to VectorType based on the LayoutAttr.
+
+    auto computeTileShapeAndCount = [&](ArrayRef<int64_t> shape,
+                                        DenseI32ArrayAttr sgDataAttr,
+                                        DenseI32ArrayAttr sgLayoutAttr) {
+      SmallVector<int64_t> tileShape;
+      auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+      if (sgDataAttr)
+        tileShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+      else
+        tileShape = computeShapeRatio(shape, sgLayout).value_or(tileShape);
+      assert(!tileShape.empty() && "failed to compute tileShape");
+      SmallVector<int64_t> distUnit =
+          computeElementwiseMul(sgLayout, tileShape);
+      int count = computeProduct(shape) / computeProduct(distUnit);
+      return std::make_pair(tileShape, count);
+    };
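+    // Worked example with hypothetical values: shape = [256, 128],
+    // sg_layout = [8, 4], sg_data = [16, 16] gives tileShape = [16, 16],
+    // distUnit = [128, 64], and count = (256*128) / (128*64) = 4, i.e.
+    // each subgroup owns four 16x16 tiles.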
+
+    TypeConverter converter;
+    converter.addConversion([&](Type type) -> Type { return type; });
+    converter.addConversion(
+        [&](RankedTensorType type,
+            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+          ArrayRef<int64_t> shape = type.getShape();
+          auto encoding = type.getEncoding();
+          Type elemTy = type.getElementType();
+
+          // Initialize count and subShape to their defaults: if no
+          // LayoutAttr is present, a single VectorType with the original
+          // shape is returned.
+          int count = 1;
+          SmallVector<int64_t> subShape(shape);
+
+          if (auto layout =
+                  llvm::dyn_cast_if_present<xegpu::LayoutAttr>(encoding)) {
+            if (layout.isWgLayout()) {
+              // For WgToSg, the subShape is either taken from sgData or
+              // computed as shape / sgLayout.
+              std::tie(subShape, count) = computeTileShapeAndCount(
+                  shape, layout.getSgData(), layout.getSgLayout());
+            } else if (DenseI32ArrayAttr instData = layout.getInstData()) {
+              // For unrolling, the subShape is determined by inst_data.
+              subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
+              count = computeProduct(shape) / computeProduct(subShape);
+            }
+          }
+          auto newTy = VectorType::get(subShape, elemTy);
+          result.append(count, newTy);
+          return success();
+        });
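+    // For instance (hypothetical values): tensor<128x128xf16> with
+    // inst_data = [8, 16] converts to 128 values of vector<8x16xf16>,
+    // since (128*128) / (8*16) = 128.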
+
+    converter.addConversion(
+        [&](xegpu::TensorDescType type,
+            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+          MLIRContext *ctx = type.getContext();
+          Type elemTy = type.getElementType();
+          Attribute encoding = type.getEncoding();
+          ArrayRef<int64_t> shape = type.getShape();
+
+          // Initialize count and newTy to their defaults: if no layout
+          // attribute is present, the original type is returned unchanged.
+          int count = 1;
+          Type newTy = type;
+
+          if (xegpu::LayoutAttr layout = type.getLayoutAttr()) {
+            SmallVector<int64_t> subShape;
+            if (layout.isWgLayout()) {
+              // For WgToSg, the subShape is either taken from sgData or
+              // computed as shape / sgLayout; the consumed sg_layout and
+              // sg_data fields are dropped from the resulting layout.
+              std::tie(subShape, count) = computeTileShapeAndCount(
+                  shape, layout.getSgData(), layout.getSgLayout());
+              layout = layout.dropSgLayoutAndData();
+            } else if (DenseI32ArrayAttr instData = layout.getInstData()) {
+              // For unrolling, the subShape is determined by inst_data,
+              // which is likewise dropped from the resulting layout.
+              subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
+              count = computeProduct(shape) / computeProduct(subShape);
+              layout = layout.dropInstData();
+            }
+            newTy = xegpu::TensorDescType::get(ctx, subShape, elemTy,
+                                               encoding, layout);
+          }
+
+          result.append(count, newTy);
+          return success();
+        });
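+    // For example (illustrative syntax): !xegpu.tensor_desc<256x128xf32,
+    // #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>> converts to a
+    // single !xegpu.tensor_desc<32x32xf32>, since the distribution unit
+    // [8*32, 4*32] = [256, 128] covers the whole shape (count = 1).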
+
+    converter.addSourceMaterialization(materializeCast);
+    converter.addTargetMaterialization(materializeCast);
+
+    mlir::ConversionTarget target(*context);
+    target.addLegalOp<UnrealizedConversionCastOp>();
+
+    mlir::RewritePatternSet patterns(context);
+    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                         target);
+    (void)mlir::applyPartialConversion(op, target, std::move(patterns));
+  }
+}