|
16 | 16 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
17 | 17 | #include "mlir/Dialect/XeGPU/IR/XeGPU.h" |
18 | 18 | #include "mlir/Dialect/XeGPU/Transforms/Transforms.h" |
| 19 | +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" |
19 | 20 | #include "mlir/Transforms/DialectConversion.h" |
20 | 21 |
|
21 | 22 | namespace mlir { |
@@ -314,13 +315,63 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> { |
314 | 315 | } |
315 | 316 | }; |
316 | 317 |
|
| 318 | +/// This pattern transforms vector.broadcast ops to work at subgroup level. |
| 319 | +/// It splits the broadcast to match the subgroup shape and drops sgLayout/sgData. |
| 320 | +struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp> { |
| 321 | + using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern; |
| 322 | + |
| 323 | + LogicalResult |
| 324 | + matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor, |
| 325 | + ConversionPatternRewriter &rewriter) const override { |
| 326 | + auto resultType = dyn_cast<VectorType>(op.getResult().getType()); |
| 327 | + if (!resultType) |
| 328 | + return rewriter.notifyMatchFailure(op, "Result is not a vector type"); |
| 329 | + |
| 330 | + // Only handle broadcasts to vectors with XeGPU layout attribute |
| 331 | + xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult()); |
| 332 | + if (!layout || !layout.getSgLayout()) |
| 333 | + return rewriter.notifyMatchFailure( |
| 334 | + op, "Result does not have a valid layout attribute for subgroup distribution"); |
| 335 | + |
| 336 | + // Extract sgShape from layout |
| 337 | + SmallVector<int64_t> sgShape; |
| 338 | + if (auto sgDataAttr = layout.getSgData()) { |
| 339 | + sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef()); |
| 340 | + } else { |
| 341 | + auto sgLayoutArr = layout.getSgLayout(); |
| 342 | + ArrayRef<int64_t> shape = resultType.getShape(); |
| 343 | + sgShape.reserve(shape.size()); |
| 344 | + for (size_t i = 0; i < shape.size(); ++i) { |
| 345 | + assert(sgLayoutArr[i] != 0 && "sgLayout elements must be non-zero"); |
| 346 | + sgShape.push_back(shape[i] / sgLayoutArr[i]); |
| 347 | + } |
| 348 | + } |
| 349 | + |
| 350 | + VectorType newResultType = VectorType::get(sgShape, resultType.getElementType()); |
| 351 | + SmallVector<Value> newBroadcasts; |
| 352 | + |
| 353 | + // The operand is always a scalar or lower-rank vector, so just broadcast for each subgroup |
| 354 | + for (Value unused : adaptor.getOperands().front()) { |
| 355 | + // All subgroups get the same broadcasted value |
| 356 | + auto newBroadcast = rewriter.create<vector::BroadcastOp>( |
| 357 | + op.getLoc(), newResultType, adaptor.getOperands().front()[0]); |
| 358 | + xegpu::setLayoutAttr(newBroadcast->getResult(0), layout.dropSgLayoutAndData()); |
| 359 | + newBroadcasts.push_back(newBroadcast.getResult()); |
| 360 | + } |
| 361 | + |
| 362 | + rewriter.replaceOpWithMultiple(op, {newBroadcasts}); |
| 363 | + return success(); |
| 364 | + } |
| 365 | +}; |
| 366 | + |
317 | 367 | } // namespace |
318 | 368 |
|
319 | 369 | namespace mlir { |
320 | 370 | namespace xegpu { |
321 | 371 | void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { |
322 | 372 | patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp, |
323 | | - WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>( |
| 373 | + WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp, |
| 374 | + WgToSgVectorBroadcastOp>( |
324 | 375 | patterns.getContext()); |
325 | 376 | } |
326 | 377 | } // namespace xegpu |
@@ -369,6 +420,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() { |
369 | 420 | return isLegal(layout); |
370 | 421 | }); |
371 | 422 |
|
| 423 | + target.addDynamicallyLegalOp<vector::BroadcastOp>([=](vector::BroadcastOp op) -> bool { |
| 424 | + auto resultType = dyn_cast<VectorType>(op.getResult().getType()); |
| 425 | + if (!resultType) |
| 426 | + return true; |
| 427 | + auto layout = xegpu::getLayoutAttr(op.getResult()); |
| 428 | + return isLegal(layout); |
| 429 | + }); |
| 430 | + |
372 | 431 | target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); |
373 | 432 |
|
374 | 433 | xegpu::populateXeGPUWgToSgDistributePatterns(patterns); |
|
0 commit comments