@@ -62,28 +62,6 @@ static Value createFullMask(PatternRewriter &rewriter, Location loc,
6262 return res.getResult ();
6363}
6464
65- // Extracts the offsets from a subview operation as values.
66- // The difference from mlir::getMixedOffsets is that this function
67- // returns the offsets as mlir::Value that can already be used as an argument
68- // for other mlir::Operations.
69- static SmallVector<Value> extractOffsetsAsValues (PatternRewriter &rewriter,
70- Location loc,
71- memref::SubViewOp subview) {
72- SmallVector<Value> offsetValues;
73- auto staticOffsets = subview.getStaticOffsets ();
74- auto dynamicOffsets = subview.getOffsets ();
75- size_t dynIdx = 0 ;
76- for (size_t i = 0 ; i < staticOffsets.size (); i++) {
77- if (staticOffsets[i] == ShapedType::kDynamic )
78- offsetValues.push_back (dynamicOffsets[dynIdx++]);
79- else
80- offsetValues.push_back (
81- rewriter.create <arith::ConstantIndexOp>(loc, staticOffsets[i]));
82- }
83-
84- return offsetValues;
85- }
86-
8765// Max number of elements to load/store from SLM
8866constexpr int64_t maxSLMTileSize = 32 ;
8967
@@ -214,7 +192,8 @@ static LogicalResult isValidMemrefOperand(linalg::LinalgOp linalgOp,
214192 linalgOp, " Expect memref operand for XeGPU lowering" );
215193 }
216194
217- if (type.getShape ().size () > maxDims) {
195+ if (type.getShape ().size () > maxDims &&
196+ !utils::canSqueezeDims (type.getShape (), maxDims)) {
218197 return rewriter.notifyMatchFailure (
219198 linalgOp, " Too high dimensionality for XeGPU operations" );
220199 }
@@ -856,43 +835,33 @@ static SmallVector<Value> createSLMDescTiles(PatternRewriter &rewriter,
856835 auto srcType = cast<MemRefType>(src.getType ());
857836 assert (srcType.getRank () == 2 && " Expected a 2D memref" );
858837
859- SmallVector<int64_t > memrefStrides;
860- Value blockOffset;
861-
838+ SmallVector<Value> offsets;
839+ Value rootMemref;
862840 // 'imex::ConvertGPUXToSPIRVPass' doesn't allow 'memref.subview' ops in the
863841 // GPU kernel. We have to merge the subview offsets into the descriptor
864842 // offset.
865- if (auto subView = dyn_cast<memref::SubViewOp>(src.getDefiningOp ())) {
866- auto offsets = extractOffsetsAsValues (rewriter, loc, subView);
867- assert (offsets.size () == 2 && " Expected 2D subview offsets" );
868-
869- auto xIntOffs = offsets[0 ];
870- auto yIntOffs = offsets[1 ];
871-
872- // compute 'blockOffset' (beginning of the subview block in the original
873- // flat memref)
874- auto rowStride =
875- cast<MemRefType>(subView.getOperand (0 ).getType ()).getShape ()[1 ];
876- auto rowStrideValue =
877- rewriter.create <arith::ConstantIndexOp>(loc, rowStride);
878-
879- auto rowBlockOffset =
880- rewriter.create <arith::MulIOp>(loc, xIntOffs, rowStrideValue)
881- .getResult ();
882- blockOffset = rewriter.create <arith::AddIOp>(loc, rowBlockOffset, yIntOffs)
883- .getResult ();
843+ utils::computeSubviewOffsets (rewriter, loc, src, offsets, rootMemref);
844+ auto rootStridesFold = utils::getMemrefStrides (rewriter, loc, rootMemref);
845+ auto rootStrides =
846+ getValueOrCreateConstantIndexOp (rewriter, loc, rootStridesFold);
884847
885- memrefStrides = {rowStride, 1 };
886- src = subView.getOperand (0 );
887- } else {
888- // If the source is not a subview, then the blockOffset is 0
889- blockOffset = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
890- memrefStrides = {srcType.getShape ()[1 ], 1 };
848+ assert (rootStrides.size () == offsets.size () &&
849+ " Expected same number of strides and offsets" );
850+
851+ // blockOffset = sum(rootStrides[i] * offsets[i])
852+ Value blockOffset = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
853+ for (size_t i = 0 ; i < rootStrides.size (); i++) {
854+ auto mul = rewriter.create <arith::MulIOp>(loc, rootStrides[i], offsets[i]);
855+ blockOffset = rewriter.create <arith::AddIOp>(loc, blockOffset, mul);
891856 }
892857
893- // Scatter descriptors only work with 1D memrefs
894- src = utils::flattenMemref (rewriter, loc, src);
858+ auto memrefStridesFold = utils::getMemrefStrides (rewriter, loc, src);
859+ auto [memrefStrides, memrefStridesDynamic] =
860+ decomposeMixedValues (memrefStridesFold);
861+ assert (memrefStridesDynamic.size () == 0 &&
862+ " Expected all values to be resolved" );
895863
864+ src = utils::flattenMemref (rewriter, loc, rootMemref);
896865 return createScatterDescriptorTiles (
897866 rewriter, loc, /* flatMemref=*/ src, /* loadShape2D=*/ loadShape,
898867 /* tileSize2D=*/ descTile, /* memrefStrides=*/ memrefStrides,
@@ -1839,6 +1808,11 @@ struct ConvertGemmLikeToXeGPU : public OpRewritePattern<LinalgOpTy> {
18391808 if (failed (isOutputValid))
18401809 return isOutputValid;
18411810
1811+ if (failed (mlir::utils::maybeSqueezeDims (rewriter, gemmLikeOp))) {
1812+ return rewriter.notifyMatchFailure (
1813+ gemmLikeOp, " Failed to squeeze dimensions of GEMM-like operation" );
1814+ }
1815+
18421816 // Ensure that reduction dimension tiling also works for smaller
18431817 // workloads.
18441818 auto aType = cast<ShapedType>(gemmLikeOp.getDpsInputs ()[0 ].getType ());
@@ -1894,6 +1868,12 @@ struct ConvertNamedEltwiseToXeGPU : public OpRewritePattern<LinalgOpTy> {
18941868 if (failed (isOutputValid))
18951869 return isOutputValid;
18961870
1871+ if (failed (utils::maybeSqueezeDims (rewriter, eltwiseOp))) {
1872+ return rewriter.notifyMatchFailure (
1873+ eltwiseOp,
1874+ " Could not squeeze dimensions of the elementwise operation" );
1875+ }
1876+
18971877 return createEltwiseKernel (eltwiseOp, rewriter);
18981878 }
18991879
@@ -1988,6 +1968,12 @@ struct ConvertMemoryFillToXeGPU : public OpRewritePattern<LinalgOpTy> {
19881968 if (failed (isOutputValid))
19891969 return isOutputValid;
19901970
1971+ if (failed (utils::maybeSqueezeDims (rewriter, linalgOp))) {
1972+ return rewriter.notifyMatchFailure (
1973+ linalgOp,
1974+ " Could not squeeze dimensions of the memory fill operation" );
1975+ }
1976+
19911977 return createMemoryFillKernel (linalgOp, rewriter);
19921978 }
19931979
0 commit comments