88#include " mlir/Dialect/XeGPU/Transforms/Passes.h"
99
1010#include " mlir/Dialect/Affine/Utils.h"
11+ #include " mlir/Dialect/Arith/IR/Arith.h"
1112#include " mlir/Dialect/Arith/Utils/Utils.h"
1213#include " mlir/Dialect/GPU/IR/GPUDialect.h"
1314#include " mlir/Dialect/Index/IR/IndexDialect.h"
1415#include " mlir/Dialect/Index/IR/IndexOps.h"
16+ #include " mlir/Dialect/Math/IR/Math.h"
1517#include " mlir/Dialect/MemRef/IR/MemRef.h"
1618#include " mlir/Dialect/Utils/IndexingUtils.h"
1719#include " mlir/Dialect/XeGPU/IR/XeGPU.h"
1820#include " mlir/Dialect/XeGPU/Transforms/Transforms.h"
1921#include " mlir/Transforms/DialectConversion.h"
22+ #include < optional>
2023
2124namespace mlir {
2225namespace xegpu {
@@ -314,6 +317,179 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
314317 }
315318};
316319
320+ // This pattern matches elementwise ops (unary/binary) in math/arith dialects
321+ // with 1D or 2D vector types
322+ template <typename Op>
323+ struct WgToSgElementwiseOp : public OpConversionPattern <Op> {
324+ using OpConversionPattern<Op>::OpConversionPattern;
325+ using OneToNOpAdaptor = typename OpConversionPattern<Op>::OneToNOpAdaptor;
326+
327+ LogicalResult
328+ matchAndRewrite (Op op, OneToNOpAdaptor adaptor,
329+ ConversionPatternRewriter &rewriter) const override {
330+ // All operands/results must be 1D or 2D vectors
331+ auto resultType = dyn_cast<VectorType>(op.getResult ().getType ());
332+ if (!resultType || (resultType.getRank () != 1 && resultType.getRank () != 2 ))
333+ return rewriter.notifyMatchFailure (
334+ op, " Result type is not a 1D or 2D vector" );
335+
336+ ArrayRef<int64_t > shape = resultType.getShape ();
337+ for (Value operand : op->getOperands ()) {
338+ auto operandType = dyn_cast<VectorType>(operand.getType ());
339+ if (!operandType || operandType.getRank () != resultType.getRank () ||
340+ operandType.getShape () != shape) {
341+ return rewriter.notifyMatchFailure (
342+ op, " Operand type is not a 1D or 2D vector with the same shape as "
343+ " result type" );
344+ }
345+ }
346+
347+ // Check for layout attribute with sgLayout
348+ auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr (" layout" ));
349+ if (!layout || !layout.getSgLayout ())
350+ return rewriter.notifyMatchFailure (
351+ op, " Operation does not have a valid layout attribute for subgroup "
352+ " distribution" );
353+
354+ // Extract sgShape from layout
355+ SmallVector<int64_t > sgShape;
356+ if (auto sgDataAttr = layout.getSgData ()) {
357+ sgShape = llvm::to_vector_of<int64_t >(sgDataAttr.asArrayRef ());
358+ } else {
359+ auto sgLayoutArr = layout.getSgLayout ();
360+ sgShape.reserve (shape.size ());
361+ for (size_t i = 0 ; i < shape.size (); ++i) {
362+ assert (sgLayoutArr[i] != 0 && " sgLayout elements must be non-zero" );
363+ sgShape.push_back (shape[i] / sgLayoutArr[i]);
364+ }
365+ }
366+
367+ // Each operand is a list of values
368+ size_t numVariants = adaptor.getOperands ().empty ()
369+ ? 0
370+ : adaptor.getOperands ().front ().size ();
371+ for (auto &operandVec : adaptor.getOperands ())
372+ if (operandVec.size () != numVariants)
373+ return rewriter.notifyMatchFailure (
374+ op, " Operand lists have mismatched sizes" );
375+
376+ SmallVector<Value> newResults;
377+
378+ auto origResultType = dyn_cast<VectorType>(op->getResult (0 ).getType ());
379+ VectorType newResultType =
380+ origResultType
381+ ? VectorType::get (sgShape, origResultType.getElementType ())
382+ : VectorType::get (sgShape, resultType.getElementType ());
383+
384+ for (size_t i = 0 ; i < numVariants; ++i) {
385+ SmallVector<Value> operands;
386+ for (auto &operandVec : adaptor.getOperands ())
387+ operands.push_back (operandVec[i]);
388+
389+ auto newOp = rewriter.create <Op>(op.getLoc (), newResultType, operands);
390+
391+ // Copy all attributes except "layout", and add "layout_result_0" with
392+ // sgLayout/data dropped
393+ for (auto attr : op->getAttrs ()) {
394+ if (attr.getName () != " layout" )
395+ newOp->setAttr (attr.getName (), attr.getValue ());
396+ }
397+ newOp->setAttr (" layout_result_0" , layout.dropSgLayoutAndData ());
398+
399+ newResults.push_back (newOp.getResult ());
400+ }
401+
402+ rewriter.replaceOpWithMultiple (op, {newResults});
403+ return success ();
404+ }
405+ };
406+
407+ // ---- ARITH ops ----
408+ using WgToSgAddFOp = WgToSgElementwiseOp<arith::AddFOp>;
409+ using WgToSgSubFOp = WgToSgElementwiseOp<arith::SubFOp>;
410+ using WgToSgNegFOp = WgToSgElementwiseOp<arith::NegFOp>;
411+ using WgToSgAddIOp = WgToSgElementwiseOp<arith::AddIOp>;
412+ using WgToSgSubIOp = WgToSgElementwiseOp<arith::SubIOp>;
413+ using WgToSgMulFOp = WgToSgElementwiseOp<arith::MulFOp>;
414+ using WgToSgMulIOp = WgToSgElementwiseOp<arith::MulIOp>;
415+ using WgToSgShLIOp = WgToSgElementwiseOp<arith::ShLIOp>;
416+ using WgToSgShRSIOp = WgToSgElementwiseOp<arith::ShRSIOp>;
417+ using WgToSgShRUIOp = WgToSgElementwiseOp<arith::ShRUIOp>;
418+ using WgToSgDivFOp = WgToSgElementwiseOp<arith::DivFOp>;
419+ using WgToSgDivSIOp = WgToSgElementwiseOp<arith::DivSIOp>;
420+ using WgToSgDivUIOp = WgToSgElementwiseOp<arith::DivUIOp>;
421+ using WgToSgMaximumFOp = WgToSgElementwiseOp<arith::MaximumFOp>;
422+ using WgToSgMinimumFOp = WgToSgElementwiseOp<arith::MinimumFOp>;
423+ using WgToSgRemSIOp = WgToSgElementwiseOp<arith::RemSIOp>;
424+ using WgToSgRemUIOp = WgToSgElementwiseOp<arith::RemUIOp>;
425+ using WgToSgTruncFOp = WgToSgElementwiseOp<arith::TruncFOp>;
426+ using WgToSgTruncIOp = WgToSgElementwiseOp<arith::TruncIOp>;
427+ using WgToSgExtFOp = WgToSgElementwiseOp<arith::ExtFOp>;
428+ using WgToSgExtSIOp = WgToSgElementwiseOp<arith::ExtSIOp>;
429+ using WgToSgExtUIOp = WgToSgElementwiseOp<arith::ExtUIOp>;
430+ using WgToSgSIToFPOp = WgToSgElementwiseOp<arith::SIToFPOp>;
431+ using WgToSgUIToFPOp = WgToSgElementwiseOp<arith::UIToFPOp>;
432+ using WgToSgFPToSIOp = WgToSgElementwiseOp<arith::FPToSIOp>;
433+ using WgToSgFPToUIOp = WgToSgElementwiseOp<arith::FPToUIOp>;
434+ using WgToSgIndexCastUIOp = WgToSgElementwiseOp<arith::IndexCastUIOp>;
435+ using WgToSgIndexCastOp = WgToSgElementwiseOp<arith::IndexCastOp>;
436+ using WgToSgBitcastOp = WgToSgElementwiseOp<arith::BitcastOp>;
437+ using WgToSgCmpIOp = WgToSgElementwiseOp<arith::CmpIOp>;
438+ using WgToSgCmpFOp = WgToSgElementwiseOp<arith::CmpFOp>;
439+ using WgToSgAndIOp = WgToSgElementwiseOp<arith::AndIOp>;
440+ using WgToSgCeilDivSIOp = WgToSgElementwiseOp<arith::CeilDivSIOp>;
441+ using WgToSgCeilDivUIOp = WgToSgElementwiseOp<arith::CeilDivUIOp>;
442+ using WgToSgFloorDivSIOp = WgToSgElementwiseOp<arith::FloorDivSIOp>;
443+ using WgToSgMaxNumFOp = WgToSgElementwiseOp<arith::MaxNumFOp>;
444+ using WgToSgMaxSIOp = WgToSgElementwiseOp<arith::MaxSIOp>;
445+ using WgToSgMaxUIOp = WgToSgElementwiseOp<arith::MaxUIOp>;
446+ using WgToSgMinNumFOp = WgToSgElementwiseOp<arith::MinNumFOp>;
447+ using WgToSgMinSIOp = WgToSgElementwiseOp<arith::MinSIOp>;
448+ using WgToSgMinUIOp = WgToSgElementwiseOp<arith::MinUIOp>;
449+ using WgToSgOrIOp = WgToSgElementwiseOp<arith::OrIOp>;
450+ using WgToSgRemFOp = WgToSgElementwiseOp<arith::RemFOp>;
451+ using WgToSgSelectOp = WgToSgElementwiseOp<arith::SelectOp>;
452+ using WgToSgXOrIOp = WgToSgElementwiseOp<arith::XOrIOp>;
453+
454+ // ---- MATH ops ----
455+ using WgToSgExpOp = WgToSgElementwiseOp<math::ExpOp>;
456+ using WgToSgSqrtOp = WgToSgElementwiseOp<math::SqrtOp>;
457+ using WgToSgAbsFOp = WgToSgElementwiseOp<math::AbsFOp>;
458+ using WgToSgCosOp = WgToSgElementwiseOp<math::CosOp>;
459+ using WgToSgCoshOp = WgToSgElementwiseOp<math::CoshOp>;
460+ using WgToSgAcosOp = WgToSgElementwiseOp<math::AcosOp>;
461+ using WgToSgAcoshOp = WgToSgElementwiseOp<math::AcoshOp>;
462+ using WgToSgSinOp = WgToSgElementwiseOp<math::SinOp>;
463+ using WgToSgSinhOp = WgToSgElementwiseOp<math::SinhOp>;
464+ using WgToSgAsinOp = WgToSgElementwiseOp<math::AsinOp>;
465+ using WgToSgAsinhOp = WgToSgElementwiseOp<math::AsinhOp>;
466+ using WgToSgTanOp = WgToSgElementwiseOp<math::TanOp>;
467+ using WgToSgTanhOp = WgToSgElementwiseOp<math::TanhOp>;
468+ using WgToSgAtanOp = WgToSgElementwiseOp<math::AtanOp>;
469+ using WgToSgAtan2Op = WgToSgElementwiseOp<math::Atan2Op>;
470+ using WgToSgAtanhOp = WgToSgElementwiseOp<math::AtanhOp>;
471+ using WgToSgErfOp = WgToSgElementwiseOp<math::ErfOp>;
472+ using WgToSgLogOp = WgToSgElementwiseOp<math::LogOp>;
473+ using WgToSgLog2Op = WgToSgElementwiseOp<math::Log2Op>;
474+ using WgToSgFloorOp = WgToSgElementwiseOp<math::FloorOp>;
475+ using WgToSgCeilOp = WgToSgElementwiseOp<math::CeilOp>;
476+ using WgToSgPowFOp = WgToSgElementwiseOp<math::PowFOp>;
477+ using WgToSgRsqrtOp = WgToSgElementwiseOp<math::RsqrtOp>;
478+ using WgToSgAbsIOp = WgToSgElementwiseOp<math::AbsIOp>;
479+ using WgToSgCbrtOp = WgToSgElementwiseOp<math::CbrtOp>;
480+ using WgToSgCopySignOp = WgToSgElementwiseOp<math::CopySignOp>;
481+ using WgToSgCtPopOp = WgToSgElementwiseOp<math::CtPopOp>;
482+ using WgToSgErfcOp = WgToSgElementwiseOp<math::ErfcOp>;
483+ using WgToSgExp2Op = WgToSgElementwiseOp<math::Exp2Op>;
484+ using WgToSgExpM1Op = WgToSgElementwiseOp<math::ExpM1Op>;
485+ using WgToSgFPowIOp = WgToSgElementwiseOp<math::FPowIOp>;
486+ using WgToSgIPowIOp = WgToSgElementwiseOp<math::IPowIOp>;
487+ using WgToSgLog10Op = WgToSgElementwiseOp<math::Log10Op>;
488+ using WgToSgLog1pOp = WgToSgElementwiseOp<math::Log1pOp>;
489+ using WgToSgRoundOp = WgToSgElementwiseOp<math::RoundOp>;
490+ using WgToSgRoundEvenOp = WgToSgElementwiseOp<math::RoundEvenOp>;
491+ using WgToSgTruncOp = WgToSgElementwiseOp<math::TruncOp>;
492+
317493} // namespace
318494
319495namespace mlir {
@@ -322,6 +498,27 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
322498 patterns.add <WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
323499 WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
324500 patterns.getContext ());
501+ // Add elementwise operations that can be distributed to subgroups
502+ patterns.add <
503+ WgToSgAddFOp, WgToSgSubFOp, WgToSgExpOp, WgToSgSqrtOp, WgToSgAbsFOp,
504+ WgToSgCosOp, WgToSgCoshOp, WgToSgAcosOp, WgToSgAcoshOp, WgToSgSinOp,
505+ WgToSgSinhOp, WgToSgAsinOp, WgToSgAsinhOp, WgToSgTanOp, WgToSgTanhOp,
506+ WgToSgAtanOp, WgToSgAtan2Op, WgToSgAtanhOp, WgToSgErfOp, WgToSgLogOp,
507+ WgToSgLog2Op, WgToSgFloorOp, WgToSgCeilOp, WgToSgPowFOp, WgToSgRsqrtOp,
508+ WgToSgNegFOp, WgToSgAddIOp, WgToSgSubIOp, WgToSgMulFOp, WgToSgMulIOp,
509+ WgToSgShLIOp, WgToSgShRSIOp, WgToSgShRUIOp, WgToSgDivFOp, WgToSgDivSIOp,
510+ WgToSgDivUIOp, WgToSgMaximumFOp, WgToSgMinimumFOp, WgToSgRemSIOp,
511+ WgToSgRemUIOp, WgToSgTruncFOp, WgToSgTruncIOp, WgToSgExtFOp,
512+ WgToSgExtSIOp, WgToSgExtUIOp, WgToSgSIToFPOp, WgToSgUIToFPOp,
513+ WgToSgFPToSIOp, WgToSgFPToUIOp, WgToSgIndexCastUIOp, WgToSgIndexCastOp,
514+ WgToSgBitcastOp, WgToSgCmpIOp, WgToSgCmpFOp, WgToSgAndIOp,
515+ WgToSgCeilDivSIOp, WgToSgCeilDivUIOp, WgToSgFloorDivSIOp, WgToSgMaxNumFOp,
516+ WgToSgMaxSIOp, WgToSgMaxUIOp, WgToSgMinNumFOp, WgToSgMinSIOp,
517+ WgToSgMinUIOp, WgToSgOrIOp, WgToSgRemFOp, WgToSgSelectOp, WgToSgXOrIOp,
518+ WgToSgAbsIOp, WgToSgCbrtOp, WgToSgCopySignOp, WgToSgCtPopOp, WgToSgErfcOp,
519+ WgToSgExp2Op, WgToSgExpM1Op, WgToSgFPowIOp, WgToSgIPowIOp, WgToSgLog10Op,
520+ WgToSgLog1pOp, WgToSgRoundOp, WgToSgRoundEvenOp, WgToSgTruncOp>(
521+ patterns.getContext ());
325522}
326523} // namespace xegpu
327524} // namespace mlir
@@ -368,6 +565,32 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
368565 auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr (" layout" ));
369566 return isLegal (layout);
370567 });
568+ target.addDynamicallyLegalDialect <math::MathDialect, arith::ArithDialect>(
569+ [=](Operation *op) -> std::optional<bool > {
570+ // Handle unary and binary operations
571+ if (op->getNumOperands () < 1 || op->getNumOperands () > 2 )
572+ return true ;
573+
574+ // check if input and output are vectors
575+ VectorType resultType =
576+ dyn_cast<VectorType>(op->getResult (0 ).getType ());
577+ if (!resultType || resultType.getRank () != 2 )
578+ return true ;
579+
580+ // Check if all operands are vectors
581+ for (Value operand : op->getOperands ()) {
582+ VectorType operandType = dyn_cast<VectorType>(operand.getType ());
583+ if (!operandType || operandType.getRank () != 2 ||
584+ operandType.getShape () != resultType.getShape ()) {
585+ return true ;
586+ }
587+ }
588+
589+ // check layout attribute
590+ auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(
591+ op->getAttrOfType <xegpu::LayoutAttr>(" layout" ));
592+ return isLegal (layout);
593+ });
371594
372595 target.markUnknownOpDynamicallyLegal ([](Operation *) { return true ; });
373596
0 commit comments