@@ -110,6 +110,66 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
110110 return success ();
111111}
112112
113+ static LogicalResult
114+ isValidGatherScatterMemRefParams (Type maskTy, VectorType valueTy,
115+ MemRefType memTy, int64_t chunkSize,
116+ function_ref<InFlightDiagnostic()> emitError) {
117+
118+ if (!valueTy)
119+ return emitError () << " Expecting a vector type result." ;
120+
121+ auto maskShape = getShapeOf (maskTy);
122+ auto valueShape = getShapeOf (valueTy);
123+ auto memShape = getShapeOf (memTy);
124+
125+ if (valueTy.getElementType () != memTy.getElementType ())
126+ return emitError () << " Value should have the same element type as MemRef." ;
127+
128+ // a valid shape for SIMT case
129+ if (valueTy.getRank () == 1 ) {
130+ if (valueTy.getNumElements () != chunkSize)
131+ return emitError () << " value elements must match chunk size " << chunkSize
132+ << " for SIMT code." ;
133+ return success ();
134+ }
135+
136+ llvm::SmallVector<int64_t > expectedMaskShape (valueShape);
137+ if (chunkSize > 1 )
138+ expectedMaskShape.pop_back ();
139+ if (expectedMaskShape != maskShape)
140+ return emitError () << " Mask should match value except the chunk size dim." ;
141+
142+ return success ();
143+ }
144+
145+ static LogicalResult
146+ isValidGatherScatterRawptrParams (Type maskTy, VectorType valueTy,
147+ int64_t chunkSize,
148+ function_ref<InFlightDiagnostic()> emitError) {
149+
150+ if (!valueTy)
151+ return emitError () << " Expecting a vector type result." ;
152+
153+ auto maskShape = getShapeOf (maskTy);
154+ auto valueShape = getShapeOf (valueTy);
155+
156+ // a valid shape for SIMT case
157+ if (valueTy.getRank () == 1 ) {
158+ if (valueTy.getNumElements () != chunkSize)
159+ return emitError () << " value elements must match chunk size " << chunkSize
160+ << " for SIMT code." ;
161+ return success ();
162+ }
163+
164+ llvm::SmallVector<int64_t > expectedMaskShape (valueShape);
165+ if (chunkSize > 1 )
166+ expectedMaskShape.pop_back ();
167+ if (expectedMaskShape != maskShape)
168+ return emitError () << " Mask should match value except the chunk size dim." ;
169+
170+ return success ();
171+ }
172+
113173// ===----------------------------------------------------------------------===//
114174// XeGPU_CreateNdDescOp
115175// ===----------------------------------------------------------------------===//
@@ -683,17 +743,27 @@ LogicalResult LoadGatherOp::verify() {
683743 if (!isReadHintOrNone (getL3HintAttr ()))
684744 return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
685745
686- return isValidGatherScatterParams (maskTy, valueTy, tdescTy,
687- [&]() { return emitOpError (); });
746+ if (tdescTy)
747+ return isValidGatherScatterParams (maskTy, valueTy, tdescTy,
748+ [&]() { return emitOpError (); });
749+ auto srcTy = getSourceType ();
750+ uint64_t chunkSize = static_cast <int64_t >(getChunkSize ().value_or (1 ));
751+ auto memTy = dyn_cast<MemRefType>(srcTy);
752+
753+ if (memTy)
754+ return isValidGatherScatterMemRefParams (maskTy, valueTy, memTy, chunkSize,
755+ [&]() { return emitOpError (); });
756+ return isValidGatherScatterRawptrParams (maskTy, valueTy, chunkSize,
757+ [&]() { return emitOpError (); });
688758}
689759
690760void LoadGatherOp::build (OpBuilder &builder, OperationState &state,
691761 Type valueType, Value source, Value mask,
692762 xegpu::CachePolicyAttr l1_hint,
693763 xegpu::CachePolicyAttr l2_hint,
694764 xegpu::CachePolicyAttr l3_hint) {
695- build (builder, state, valueType, source, ValueRange (), DenseI64ArrayAttr (),
696- mask, IntegerAttr (), l1_hint, l2_hint, l3_hint);
765+ build (builder, state, valueType, source, Value (), mask, IntegerAttr (),
766+ l1_hint, l2_hint, l3_hint);
697767}
698768
699769// ===----------------------------------------------------------------------===//
@@ -713,17 +783,28 @@ LogicalResult StoreScatterOp::verify() {
713783 if (!isWriteHintOrNone (getL3HintAttr ()))
714784 return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
715785
716- return isValidGatherScatterParams (maskTy, valueTy, tdescTy,
717- [&]() { return emitOpError (); });
786+ if (tdescTy)
787+ return isValidGatherScatterParams (maskTy, valueTy, tdescTy,
788+ [&]() { return emitOpError (); });
789+
790+ auto destTy = getDestType ();
791+ uint64_t chunkSize = static_cast <int64_t >(getChunkSize ().value_or (1 ));
792+ auto memTy = dyn_cast<MemRefType>(destTy);
793+
794+ if (memTy)
795+ return isValidGatherScatterMemRefParams (maskTy, valueTy, memTy, chunkSize,
796+ [&]() { return emitOpError (); });
797+ return isValidGatherScatterRawptrParams (maskTy, valueTy, chunkSize,
798+ [&]() { return emitOpError (); });
718799}
719800
720801void StoreScatterOp::build (OpBuilder &builder, OperationState &state,
721802 Value value, Value dest, Value mask,
722803 xegpu::CachePolicyAttr l1_hint,
723804 xegpu::CachePolicyAttr l2_hint,
724805 xegpu::CachePolicyAttr l3_hint) {
725- build (builder, state, value, dest, ValueRange (), DenseI64ArrayAttr (), mask ,
726- IntegerAttr (), l1_hint, l2_hint, l3_hint);
806+ build (builder, state, value, dest, Value (), mask, IntegerAttr (), l1_hint ,
807+ l2_hint, l3_hint);
727808}
728809
729810// ===----------------------------------------------------------------------===//
0 commit comments