 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
@@ -329,13 +330,12 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {

 // Handles UnrealizedConversionCastOp generated during
 // SCFStructuralTypeConversions (step 1). This op may appear as either a
-// target or source materialization for Vector or TensorDesc values, e.g.:
+// target or source materialization for Vector values, e.g.:
 // 1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
 // 2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
-// it could be either 1:1, 1:N or N:1 cast. In all cases, the pattern
+// It can be either a 1:N or an N:1 cast. In both cases, the pattern
 // simply forwards the inputs to the outputs using 1:1 or 1:N interface.
 // TODO: remove it when context-aware type converter is ready.
-// It is safe only when input codes don't contain UnrealizedConversionCastOp.
 struct UnrealizedConversionCastOpPattern
     : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
   using OpConversionPattern<
@@ -346,28 +346,30 @@ struct UnrealizedConversionCastOpPattern
                   ConversionPatternRewriter &rewriter) const override {
     SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());

-    auto inputTy = inputs[0].getType();
-    auto outputTy = op->getOpResult(0).getType();
+    auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
+    auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());

-    if (!llvm::all_equal(op->getResultTypes()) ||
-        !llvm::all_equal(ValueRange(inputs).getTypes()) ||
-        !isa<VectorType, xegpu::TensorDescType>(inputTy) ||
-        !isa<VectorType, xegpu::TensorDescType>(outputTy))
+    if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
+        !llvm::all_equal(ValueRange(inputs).getTypes()))
       return failure();

     // Handles the case where cast %1 : vector<256xf32> to vector<16xf32>, ...
     // the input values provided by the adaptor should already be distributed,
     // and their types should correspond exactly to the result types of the
     // operation.
-    if (op.getNumOperands() == 1) {
+    if (op.getNumOperands() == 1 &&
+        llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
       rewriter.replaceOp(op, inputs);
       return success();
     }

     // Handles the case where cast %1 : vector<16xf32>, ... to vector<256xf32>.
     // All input values must have the same vector type, and their shape must be
     // evenly divisible by the output vector's shape.
-    if (op.getNumResults() == 1) {
+    // TODO: forwarding here is not fully safe, since an N:1 cast could also
+    // originate from sources other than this SCF structural type conversion.
+    if (op.getNumResults() == 1 &&
+        computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
       rewriter.replaceOpWithMultiple(op, {inputs});
       return success();
     }
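For reference on the guard introduced above: computeShapeRatio (declared in mlir/Dialect/Utils/IndexingUtils.h) returns the per-dimension ratio when the first shape is evenly divisible by the second, and std::nullopt otherwise, so mismatched N:1 casts now fail instead of being forwarded. A minimal sketch with illustrative shapes (the wrapper function below is hypothetical):

// Sketch only: how the divisibility check behaves for the shapes in the
// comments above (vector<256xf32> rebuilt from vector<16xf32> pieces).
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include <cassert>

void shapeRatioSketch() {
  // Evenly divisible: ratio {16}, so the N:1 forwarding is allowed.
  assert(mlir::computeShapeRatio({256}, {16}).has_value());
  // Not evenly divisible: std::nullopt, so the pattern returns failure().
  assert(!mlir::computeShapeRatio({250}, {16}).has_value());
}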
@@ -396,6 +398,7 @@ struct XeGPUWgToSgDistributePass
 } // namespace

 void XeGPUWgToSgDistributePass::runOnOperation() {
+
   TypeConverter converter;
   converter.addConversion([&](Type type) -> Type { return type; });
   converter.addConversion(
@@ -414,6 +417,16 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return success();
       });

+  // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
+  // VectorType operands. This first converts such operands to RankedTensorType,
+  // propagates the layout attribute into the encoding attribute, and finally
+  // converts the RankedTensorType to VectorType based on the encoding.
+  xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(), converter);
+
+  MLIRContext *ctx = &getContext();
+  RewritePatternSet patterns(ctx);
+  ConversionTarget target(*ctx);
+
   converter.addConversion(
       [&](xegpu::TensorDescType type,
           SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
@@ -434,13 +447,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return success();
       });

-  // step1: perform SCFStructuralTypeConversions on SCF ops
-  xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(), converter);
-
-  MLIRContext *ctx = &getContext();
-  RewritePatternSet patterns(ctx);
-  ConversionTarget target(*ctx);
-
   auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
     if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
       return createOp.getType();
@@ -476,7 +482,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() {

   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

-  // step2: Perform for workgroup to subgroup distribution for rest ops
+  // Step 2: Perform workgroup to subgroup distribution for TensorDesc values,
+  // as well as XeGPU, Arith, and Vector operations.
+  scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                        target);
   xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
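As context for the 1:N conversions this pass registers (like the TensorDescType conversion above), here is a sketch of what such a callback can look like for a plain vector type; the 16-way split, the shapes, and addExampleVectorConversion are illustrative assumptions, not the converter actually installed by this patch:

// Sketch of a 1:N TypeConverter callback: one workgroup-level vector maps to
// one subgroup-level vector per subgroup (hypothetical 16 subgroups).
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include <optional>

static void addExampleVectorConversion(mlir::TypeConverter &converter) {
  converter.addConversion(
      [](mlir::VectorType type, llvm::SmallVectorImpl<mlir::Type> &result)
          -> std::optional<mlir::LogicalResult> {
        // Only handle the illustrative vector<256xf32> case; defer the rest
        // to other registered conversions by returning std::nullopt.
        if (type.getRank() != 1 || type.getNumElements() != 256)
          return std::nullopt;
        auto sgTy = mlir::VectorType::get({16}, type.getElementType());
        result.append(16, sgTy); // 16 subgroup-sized result types
        return mlir::success();
      });
}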