@@ -106,12 +106,12 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
106106 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
107107
108108 // Calculate offset for each subgroup
109- SmallVector<OpFoldResult>
109+ static SmallVector<OpFoldResult>
110110 calculateGlobalOffsets (ConversionPatternRewriter &rewriter, Location loc,
111111 const SmallVector<OpFoldResult> &originalOffsets,
112112 const SmallVector<Value> &localOffset,
113113 const SmallVector<int64_t > &distUnitBaseAddr,
114- const SmallVector<int64_t > &distUnitShape) const {
114+ const SmallVector<int64_t > &distUnitShape) {
115115 assert (localOffset.size () == distUnitBaseAddr.size () &&
116116 " localOffset and distUnitBaseAddr must have the same rank" );
117117
@@ -466,6 +466,75 @@ struct WgToSgElementwiseOp : public ConversionPattern {
466466 }
467467};
468468
469+ // clang-format off
470+ // Pattern for lowering ConvertLayoutOp based on sg_layout and sg_data.
471+ // If input_layout and target_layout have identical sg_layout and sg_data,
472+ // the op is rewritten to a subgroup-level ConvertLayoutOp with these fields
473+ // dropped. For example:
474+ // #a = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>
475+ // #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>
476+ // xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
477+ // becomes:
478+ // #a = #xegpu.layout<inst_data = [16, 16]>
479+ // #b = #xegpu.layout<inst_data = [8, 16]>
480+ // xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<16x16xf32>
481+ // (vector<16x16xf32> is determined by sg_data = [16, 16])
482+ //
483+ // If sg_layout or sg_data differ, SLM is used to redistribute data across subgroups.
484+ // For example:
485+ // #a = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 16], inst_data = [16, 16]>
486+ // #b = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 32], inst_data = [8, 16]>
487+ // xegpu.convert_layout %1 <{input_layout = #a, target_layout = #b}> : vector<32x64xf32>
488+ // is lowered to:
489+ // #a = #xegpu.layout<inst_data = [16, 16]>
490+ // #b = #xegpu.layout<inst_data = [8, 16]>
491+ // store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16xf32>, matrix_desc<32x64xf32>
492+ // %d = load_matrix %slm <{layout_result_0 = #a}> : matrix_desc<32x64xf32> -> vector<16x32xf32>
493+ // xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
494+ // clang-format on
495+ struct WgToSgConvertLayoutOp
496+ : public OpConversionPattern<xegpu::ConvertLayoutOp> {
497+ using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
498+ LogicalResult
499+ matchAndRewrite (xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
500+ ConversionPatternRewriter &rewriter) const override {
501+ xegpu::LayoutAttr input = op.getInputLayout ();
502+ xegpu::LayoutAttr target = op.getTargetLayout ();
503+
504+ if (!input || !target || !input.isWgLayout () || !target.isWgLayout ())
505+ return rewriter.notifyMatchFailure (
506+ op, " Input and target layouts must have subgroup layout" );
507+
508+ DenseI32ArrayAttr inputSgLayout = input.getSgLayout ();
509+ DenseI32ArrayAttr inputSgData = input.getSgData ();
510+ DenseI32ArrayAttr inputOrder = input.getOrder ();
511+ DenseI32ArrayAttr targetSgLayout = target.getSgLayout ();
512+ DenseI32ArrayAttr targetSgData = target.getSgData ();
513+ DenseI32ArrayAttr targetOrder = target.getOrder ();
514+
515+ // TODO: currently we only support for optimal case, where input and
516+ // output has the same sg_layout and sg_data, so SLM is not involved.
517+ if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
518+ inputOrder != targetOrder)
519+ return failure ();
520+
521+ input = input.dropSgLayoutAndData ();
522+ target = target.dropSgLayoutAndData ();
523+
524+ SmallVector<Value> newOps (adaptor.getSource ());
525+ if (input && target) {
526+ // keep the ConvertLayoutOp for rest fields, e.g., inst_data.
527+ for (auto [i, src] : llvm::enumerate (adaptor.getSource ())) {
528+ auto newOp = rewriter.create <xegpu::ConvertLayoutOp>(
529+ op.getLoc (), src.getType (), src, input, target);
530+ newOps[i] = newOp;
531+ }
532+ }
533+ rewriter.replaceOpWithMultiple (op, {newOps});
534+ return success ();
535+ }
536+ };
537+
469538// Handles UnrealizedConversionCastOp generated during
470539// SCFStructuralTypeConversions (step 1). This op may appear as either a
471540// target or source materialization for Vector values, e.g.:
@@ -550,7 +619,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
550619 patterns.add <WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
551620 WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
552621 UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
553- WgToSgVectorBroadcastOp>(patterns.getContext ());
622+ WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
623+ patterns.getContext ());
554624}
555625} // namespace xegpu
556626} // namespace mlir
@@ -662,6 +732,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
662732 return isLegal (xegpu::getLayoutAttr (op.getResult ()));
663733 });
664734
735+ target.addDynamicallyLegalOp <xegpu::ConvertLayoutOp>(
736+ [=](xegpu::ConvertLayoutOp op) -> bool {
737+ return isLegal (op.getInputLayout ()) && isLegal (op.getTargetLayout ());
738+ });
739+
665740 target.addDynamicallyLegalDialect <math::MathDialect, arith::ArithDialect>(
666741 [=](Operation *op) -> std::optional<bool > {
667742 // Only handle elementwise mappable ops
0 commit comments