File tree Expand file tree Collapse file tree 1 file changed +16
-0
lines changed
mlir/lib/Dialect/XeGPU/Transforms Expand file tree Collapse file tree 1 file changed +16
-0
lines changed Original file line number Diff line number Diff line change @@ -785,6 +785,14 @@ struct WgToSgLoadGatherOpWithOffset
785785
786786 SmallVector<int64_t > sgShape = getSgShapeAndCount (wgShape, layout).first ;
787787
788+ // The offsets need to be distributed
789+ if (dyn_cast<VectorType>(adaptor.getOffsets ().front ().getType ())
790+ .getShape () !=
791+ dyn_cast<VectorType>(adaptor.getMask ().front ().getType ()).getShape ()) {
792+ return rewriter.notifyMatchFailure (op,
793+ " offsets have not been distributed" );
794+ }
795+
788796 SmallVector<Value> newLoadOps;
789797 auto chunkSizeAttr =
790798 rewriter.getI64IntegerAttr (op.getChunkSize ().value_or (1 ));
@@ -824,6 +832,14 @@ struct WgToSgStoreScatterOpWithOffset
824832 if (!layout || !layout.isForWorkgroup ())
825833 return failure ();
826834
835+ // The offsets need to be distributed
836+ if (dyn_cast<VectorType>(adaptor.getOffsets ().front ().getType ())
837+ .getShape () !=
838+ dyn_cast<VectorType>(adaptor.getMask ().front ().getType ()).getShape ()) {
839+ return rewriter.notifyMatchFailure (op,
840+ " offsets have not been distributed" );
841+ }
842+
827843 auto chunkSizeOpt = op.getChunkSize ();
828844 int64_t chunkSize = chunkSizeOpt ? static_cast <int64_t >(*chunkSizeOpt) : 1 ;
829845 auto chunkSizeAttr = rewriter.getI64IntegerAttr (chunkSize);
You can’t perform that action at this time.
0 commit comments