File tree Expand file tree Collapse file tree 1 file changed +12
-6
lines changed
mlir/lib/Dialect/XeGPU/Transforms Expand file tree Collapse file tree 1 file changed +12
-6
lines changed Original file line number Diff line number Diff line change @@ -786,9 +786,12 @@ struct WgToSgLoadGatherOpWithOffset
786786 SmallVector<int64_t > sgShape = getSgShapeAndCount (wgShape, layout).first ;
787787
788788 // The offsets need to be distributed
789- if (dyn_cast<VectorType>(adaptor.getOffsets ().front ().getType ())
790- .getShape () !=
791- dyn_cast<VectorType>(adaptor.getMask ().front ().getType ()).getShape ()) {
789+ auto offsetsVecType =
790+ dyn_cast<VectorType>(adaptor.getOffsets ().front ().getType ());
791+ auto maskVecType =
792+ dyn_cast<VectorType>(adaptor.getMask ().front ().getType ());
793+ if (!offsetsVecType || !maskVecType ||
794+ offsetsVecType.getShape () != maskVecType.getShape ()) {
792795 return rewriter.notifyMatchFailure (op,
793796 " offsets have not been distributed" );
794797 }
@@ -833,9 +836,12 @@ struct WgToSgStoreScatterOpWithOffset
833836 return failure ();
834837
835838 // The offsets need to be distributed
836- if (dyn_cast<VectorType>(adaptor.getOffsets ().front ().getType ())
837- .getShape () !=
838- dyn_cast<VectorType>(adaptor.getMask ().front ().getType ()).getShape ()) {
839+ auto offsetsVecType =
840+ dyn_cast<VectorType>(adaptor.getOffsets ().front ().getType ());
841+ auto maskVecType =
842+ dyn_cast<VectorType>(adaptor.getMask ().front ().getType ());
843+ if (!offsetsVecType || !maskVecType ||
844+ offsetsVecType.getShape () != maskVecType.getShape ()) {
839845 return rewriter.notifyMatchFailure (op,
840846 " offsets have not been distributed" );
841847 }
You can’t perform that action at this time.
0 commit comments