 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/DialectConversion.h"

 namespace mlir {
@@ -29,6 +31,29 @@ using namespace mlir;

 namespace {

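+// Returns the subgroup-level shape implied by `layout` for a value of the
+// given workgroup-level `shape`, together with the number of values of that
+// shape the original value is split into per subgroup.
+// For illustration (values assumed): shape = [256, 128] with
+// sg_layout = [2, 4] and sg_data = [32, 32] gives sgShape = [32, 32] and
+// count = (256 * 128) / (64 * 128) = 4.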
+static std::pair<SmallVector<int64_t>, int>
+getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
+  int count = 1;
+  SmallVector<int64_t> sgShape(shape);
+
+  if (layout && layout.isWgLayout()) {
+    DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout();
+    auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+    if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
+      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+    else
+      sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape);
+    SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
+    // Clamp distUnit to the original shape to handle cases where data is
+    // shared among subgroups, which may cause distUnit to exceed the original
+    // shape.
+    for (size_t i = 0; i < distUnit.size(); ++i)
+      distUnit[i] = std::min(shape[i], distUnit[i]);
+    count = computeProduct(shape) / computeProduct(distUnit);
+  }
+  return std::make_pair(sgShape, count);
+}
+
 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
 /// from a workgroup descriptor. It replaces the offsets and sizes with
 /// appropriate values for the subgroup.
@@ -129,18 +154,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       return rewriter.notifyMatchFailure(
           op, "sgLayout attribute is required in layout");

-    SmallVector<int64_t> sgShape;
-    if (auto sgDataAttr = layout.getSgData()) {
-      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
-    } else {
-      assert(wgShape.size() == sgLayout.size() &&
-             "sgLayout and wgShape must have the same rank");
-      sgShape.reserve(wgShape.size());
-      for (size_t i = 0; i < wgShape.size(); ++i) {
-        assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero");
-        sgShape.push_back(wgShape[i] / sgLayout[i]);
-      }
-    }
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

     // TODO: Handle order attribute
     // Get the subgroup ID
@@ -266,15 +280,15 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
     if (resultTy.getRank() != 2)
       return failure();

-    auto originalLayout =
-        llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    auto originalLayout = xegpu::getLayoutAttr(op.getResult());
     if (!originalLayout)
       return failure();

-    SmallVector<Value> newDpasOps;
     size_t i = 0;
+    SmallVector<Value> newDpasOps;
     for (auto aVec : adaptor.getLhs()) {
       for (auto bVec : adaptor.getRhs()) {
+
         llvm::SmallVector<Value> operands({aVec, bVec});
         Value tmpC;
         if (op.getAcc()) {
@@ -288,10 +302,10 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
             llvm::cast<VectorType>(bVec.getType()).getShape();
         VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                            resultTy.getElementType());
-        tmpC = rewriter.create<xegpu::DpasOp>(
-            loc, resTy, operands,
-            llvm::ArrayRef<NamedAttribute>(
-                {"layout_result_0", originalLayout.dropSgLayoutAndData()}));
+        tmpC = rewriter.create<xegpu::DpasOp>(loc, resTy, operands);
+        xegpu::setLayoutAttr(cast<OpResult>(tmpC),
+                             originalLayout.dropSgLayoutAndData());
+
         newDpasOps.push_back(tmpC);
       }
     }
@@ -314,14 +328,90 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };

+// Handles the UnrealizedConversionCastOps generated during
+// SCFStructuralTypeConversions (step 1). Such an op may appear as either a
+// target or a source materialization for Vector values, e.g.:
+//   1. unrealized_cast %1 : vector<256xf32> to vector<16xf32>, ...
+//   2. unrealized_cast %1 : vector<16xf32>, ... to vector<256xf32>
+// It can be either a 1:N or an N:1 cast. In both cases, the pattern simply
+// forwards the inputs to the outputs using the 1:1 or 1:N interface.
+// For example, the following scf::ForOp
+// ```
+// %for = scf.for ... iter_args(%arg1 = %0) -> (vector<128x128xf16>) {
+//   %n = use(%arg1) : vector<128x128xf16>
+//   scf.yield %n : vector<128x128xf16>
+// }
+// ```
+// could be converted to:
+// ```
+// %1 = unrealized_conversion_cast %0
+//        : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
+// %for:2 = scf.for ... iter_args(%arg1 = %1#0, %arg2 = %1#1)
+//        -> (vector<16x16xf16>, vector<16x16xf16>) {
+//   %m = unrealized_conversion_cast %arg1, %arg2
+//        : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
+//   %n = use(%m) : vector<128x128xf16>
+//   %b = unrealized_conversion_cast %n
+//        : vector<128x128xf16> to vector<16x16xf16>, vector<16x16xf16>
+//   scf.yield %b#0, %b#1 : vector<16x16xf16>, vector<16x16xf16>
+// }
+// %cast = unrealized_conversion_cast %for#0, %for#1
+//        : vector<16x16xf16>, vector<16x16xf16> to vector<128x128xf16>
+// ```
+// TODO: Remove this pattern once a context-aware type converter is available.
+struct UnrealizedConversionCastOpPattern
+    : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
+  using OpConversionPattern<
+      mlir::UnrealizedConversionCastOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());
+
+    auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
+    auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());
+
+    if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
+        !llvm::all_equal(ValueRange(inputs).getTypes()))
+      return failure();
+
+    // Handles the case "cast %1 : vector<256xf32> to vector<16xf32>, ...".
+    // It is generated by source materialization (e.g., inits to scf.for).
+    // The input values provided by the adaptor should already be distributed,
+    // and their types should correspond exactly to the result types of the
+    // operation.
+    if (op.getNumOperands() == 1 &&
+        llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
+      rewriter.replaceOp(op, inputs);
+      return success();
+    }
+
+    // Handles the case "cast %1 : vector<16xf32>, ... to vector<256xf32>".
+    // It is generated by target materialization (e.g., arguments/results of
+    // scf.for). All input values must have the same vector type, and their
+    // shape must evenly divide the output vector's shape (a property of the
+    // workgroup-to-subgroup distribution).
+    // TODO: This forwarding is not strictly safe, since such an N:1 cast
+    // could also originate from other sources.
+    if (op.getNumResults() == 1 &&
+        computeShapeRatio(outputTy.getShape(), inputTy.getShape())) {
+      rewriter.replaceOpWithMultiple(op, {inputs});
+      return success();
+    }
+
+    return mlir::failure();
+  }
+};
+
 } // namespace

 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
-               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
-      patterns.getContext());
+               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+               UnrealizedConversionCastOpPattern>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -334,9 +424,68 @@ struct XeGPUWgToSgDistributePass
 } // namespace

 void XeGPUWgToSgDistributePass::runOnOperation() {
+  // Track existing UnrealizedConversionCastOps.
+  SmallVector<Operation *> existingCastOps;
+  getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
+    existingCastOps.push_back(castOp.getOperation());
+  });
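+  // Pre-existing casts are marked legal by the conversion target below, so
+  // only the casts introduced by the SCF structural type conversion in step 1
+  // are resolved by UnrealizedConversionCastOpPattern.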
+
+  {
+    // Step 1: Apply SCFStructuralTypeConversions to SCF operations with
+    // VectorType operands. This first converts such operands to
+    // RankedTensorType, propagates the layout attribute into the encoding
+    // attribute, and finally converts the RankedTensorType to VectorType
+    // based on the encoding.
+
+    TypeConverter converter;
+    converter.addConversion([&](Type type) -> Type { return type; });
+    converter.addConversion(
+        [&](RankedTensorType type,
+            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+          Type elemTy = type.getElementType();
+          ArrayRef<int64_t> shape = type.getShape();
+
+          int count;
+          SmallVector<int64_t> subShape;
+          std::tie(subShape, count) = getSgShapeAndCount(
+              shape,
+              dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));
+
+          auto newTy = VectorType::get(subShape, elemTy);
+          result.append(count, newTy);
+          return success();
+        });
+
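+    // For illustration (shapes assumed): with this converter, a value of type
+    // tensor<256x128xf32> whose encoding carries sg_layout = [2, 4] and
+    // sg_data = [32, 32] is converted to 4 values of type vector<32x32xf32>.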
+    xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
+                                                       converter);
+  }
+
+  // Step 2: Perform workgroup-to-subgroup distribution for TensorDesc values,
+  // as well as XeGPU, Arith, and Vector operations.
   MLIRContext *ctx = &getContext();
   RewritePatternSet patterns(ctx);
   ConversionTarget target(*ctx);
+  TypeConverter converter;
+  converter.addConversion([&](Type type) -> Type { return type; });
+  converter.addConversion(
+      [&](xegpu::TensorDescType type,
+          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+        Type elemTy = type.getElementType();
+        ArrayRef<int64_t> shape = type.getShape();
+
+        int count;
+        SmallVector<int64_t> subShape;
+        xegpu::LayoutAttr layout = type.getLayoutAttr();
+        std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
+
+        if (layout)
+          layout = layout.dropSgLayoutAndData();
+
+        auto newTy = xegpu::TensorDescType::get(
+            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
+        result.append(count, newTy);
+        return success();
+      });
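+  // For illustration (shapes assumed): a !xegpu.tensor_desc<256x128xf32,
+  // #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>> becomes 4 values
+  // of !xegpu.tensor_desc<32x32xf32>, with sg_layout/sg_data dropped from
+  // the layout.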

   auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
     if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
@@ -353,26 +502,49 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
   };

   auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
-    return !layout || layout.getSgLayout() == nullptr;
+    return !layout || !layout.isWgLayout();
   };

   target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                                xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
                                xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
     auto tdescTy = getTensorDescType(op);
-    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
+    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
     return isLegal(layout);
   });

   target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
-    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    auto layout = xegpu::getLayoutAttr(op.getResult());
     return isLegal(layout);
   });

+  target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
+      [=](UnrealizedConversionCastOp op) {
+        return llvm::is_contained(existingCastOps, op.getOperation());
+      });
+
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

+  scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                       target);
   xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
     return signalPassFailure();
+
+  // Remove the sg_layout and sg_data fields from the layout attribute attached
+  // to each VectorType result of an operation.
+  // For structured control flow ops, the layout is removed entirely, since in
+  // the 1:N case the layouts for the new results are missing; the layout
+  // propagation pass will be run afterwards to recover them.
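+  // For illustration (attribute values assumed): layout_result_0 =
+  // #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32], inst_data = [8, 16]>
+  // becomes #xegpu.layout<inst_data = [8, 16]> on a non-SCF op, and is removed
+  // entirely from scf.for/scf.if/scf.while results.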
+  getOperation()->walk([](Operation *op) {
+    for (OpResult result : op->getOpResults()) {
+      std::string name = xegpu::getLayoutName(result);
+      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
+        op->removeAttr(name);
+        if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op))
+          op->setAttr(name, layout.dropSgLayoutAndData());
+      }
+    }
+  });
 }