
Commit 605eee0 ("refine")
Parent: bf37af1

1 file changed: +28 −19 lines

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
@@ -329,13 +330,12 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
 
 // Handles UnrealizedConversionCastOp generated during
 // SCFStructuralTypeConversions (step 1). This op may appear as either a
-// target or source materialization for Vector or TensorDesc values, e.g.:
+// target or source materialization for Vector values, e.g.:
 // 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
 // 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
-// it could be either 1:1, 1:N or N:1 cast. In all cases, the pattern
+// It could be either a 1:N or an N:1 cast. In both cases, the pattern
 // simply forwards the inputs to the outputs using 1:1 or 1:N interface.
 // TODO: remove it when context-aware type converter is ready.
-// It is safe only when input codes don't contain UnrealizedConversionCastOp.
 struct UnrealizedConversionCastOpPattern
     : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
   using OpConversionPattern<
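
For context on the "1:1 or 1:N interface" the comment mentions: in the dialect-conversion framework these correspond to two rewriter calls, both of which appear in the pattern body below. A minimal sketch, with op and inputs named as in matchAndRewrite:

  // N:1 direction: the single wide result is replaced by all N distributed
  // inputs at once (one result mapped to a whole ValueRange).
  rewriter.replaceOpWithMultiple(op, {inputs});

  // 1:1 direction: each result is replaced by exactly one input of the
  // same, already-distributed type.
  rewriter.replaceOp(op, inputs);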
@@ -346,28 +346,30 @@ struct UnrealizedConversionCastOpPattern
                   ConversionPatternRewriter &rewriter) const override {
     SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
 
-    auto inputTy = inputs[0].getType();
-    auto outputTy = op->getOpResult(0).getType();
+    auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
+    auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
 
-    if (!llvm::all_equal(op->getResultTypes()) ||
-        !llvm::all_equal(ValueRange(inputs).getTypes()) ||
-        !isa<VectorType, xegpu::TensorDescType>(inputTy) ||
-        !isa<VectorType, xegpu::TensorDescType>(outputTy))
+    if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
+        !llvm::all_equal(ValueRange(inputs).getTypes()))
       return failure();
 
     // Handles the case where cast %1 : vector<256xf32> to vector<16xf32>, ...
     // the input values provided by the adaptor should already be distributed,
     // and their types should correspond exactly to the result types of the
     // operation.
-    if (op.getNumOperands() == 1) {
+    if (op.getNumOperands() == 1 &&
+        llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
       rewriter.replaceOp(op, inputs);
       return success();
     }
 
     // Handles the case where cast %1 : vector<16xf32>, ... to vector<256xf32>.
     // All input values must have the same vector type, and their shape must be
     // evenly divisible by the output vector's shape.
-    if (op.getNumResults() == 1) {
+    // TODO: forwarding like this is not always safe, since such an N:1 cast
+    // could also originate from other sources.
+    if (op.getNumResults() == 1 &&
+        computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
       rewriter.replaceOpWithMultiple(op, {inputs});
       return success();
     }
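
The new guard leans on computeShapeRatio from mlir/Dialect/Utils/IndexingUtils.h (already in the include list above), which returns the per-dimension ratio as a std::optional<SmallVector<int64_t>> when the first shape is evenly divisible by the second, and std::nullopt otherwise; used in a boolean context it is exactly the divisibility check the comment asks for. A small illustrative sketch, values assumed:

  #include "mlir/Dialect/Utils/IndexingUtils.h"

  // vector<256xf32> distributed as 16 pieces of vector<16xf32>:
  std::optional<SmallVector<int64_t>> ratio =
      mlir::computeShapeRatio(/*shape=*/{256}, /*subShape=*/{16}); // {16}

  // A non-divisible split yields std::nullopt, so the pattern bails out:
  assert(!mlir::computeShapeRatio({256}, {48}).has_value());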
@@ -396,6 +398,7 @@ struct XeGPUWgToSgDistributePass
 } // namespace
 
 void XeGPUWgToSgDistributePass::runOnOperation() {
+
   TypeConverter converter;
   converter.addConversion([&](Type type) -> Type { return type; });
   converter.addConversion(
@@ -414,6 +417,16 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return success();
       });
 
+  // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
+  // VectorType operands. This first converts such operands to RankedTensorType,
+  // propagates the layout attribute into the encoding attribute, and finally
+  // converts the RankedTensorType to VectorType based on the encoding.
+  xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(), converter);
+
+  MLIRContext *ctx = &getContext();
+  RewritePatternSet patterns(ctx);
+  ConversionTarget target(*ctx);
+
   converter.addConversion(
       [&](xegpu::TensorDescType type,
           SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
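
The Step 1 comment describes a round-trip through RankedTensorType so that the layout can ride along as the tensor's encoding attribute (VectorType has no encoding). A hedged sketch of the idea, assuming vecTy is the original VectorType and layout is the xegpu layout attribute; the actual plumbing lives inside doSCFStructuralTypeConversionWithTensorType:

  // VectorType -> RankedTensorType, carrying the layout as the encoding:
  Type withEncoding = RankedTensorType::get(
      vecTy.getShape(), vecTy.getElementType(), /*encoding=*/layout);
  // ...SCF structural conversion runs over the tensor-typed values...
  // RankedTensorType -> distributed VectorType, shape derived from encoding.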
@@ -434,13 +447,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return success();
       });
 
-  // step1: perform SCFStructuralTypeConversions on SCF ops
-  xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(), converter);
-
-  MLIRContext *ctx = &getContext();
-  RewritePatternSet patterns(ctx);
-  ConversionTarget target(*ctx);
-
   auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
     if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
       return createOp.getType();
@@ -476,7 +482,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
 
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
-  // step2: Perform for workgroup to subgroup distribution for rest ops
+  // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
+  // as well as XeGPU, Arith, and Vector operations.
+  scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                       target);
   xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
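
Taken together, the hunks leave the run method ordered roughly as below; a condensed sketch assembled from the diff, with unchanged registration code elided and the failure handling assumed (the diff ends at applyPartialConversion):

  void XeGPUWgToSgDistributePass::runOnOperation() {
    TypeConverter converter;
    converter.addConversion([&](Type type) -> Type { return type; });
    // ...RankedTensorType conversion used by step 1...

    // Step 1: SCF structural type conversion, now run before the
    // TensorDescType conversion is registered.
    xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
                                                       converter);

    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    ConversionTarget target(*ctx);
    // ...TensorDescType conversion and legality rules...

    // Step 2: workgroup-to-subgroup distribution.
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure(); // assumed; not shown in the diff
  }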
