@@ -73,6 +73,29 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
7373 kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
7474}
7575
76+ // Validations for nd instruction arguments is successful if any of these are
77+ // true:
78+ // - tensor descriptor and the output vector shapes exactly match.
79+ // - tensor descriptor has a sg_map attribute and the distributed vector shape
80+ // matches the tensor descriptor shape when scaled using sg_map factors on
81+ // each dimension.
82+ static bool isArgShapesValid (ArrayRef<int64_t > descShape,
83+ ArrayRef<int64_t > valShape, SGMapAttr sgMap) {
84+ if (descShape == valShape)
85+ return true ;
86+
87+ if (!sgMap)
88+ return false ;
89+
90+ for (const auto &[factor, dim, expected] :
91+ llvm::zip_equal (sgMap.getWiLayout (), valShape, descShape)) {
92+ if (factor * dim != expected)
93+ return false ;
94+ }
95+
96+ return true ;
97+ }
98+
7699// ===----------------------------------------------------------------------===//
77100// XeGPU_CreateNdDescOp
78101// ===----------------------------------------------------------------------===//
@@ -210,13 +233,13 @@ LogicalResult PrefetchNdOp::verify() {
210233 return emitOpError (" Expects a non-scattered TensorDesc.\n " );
211234
212235 if (!isReadHintOrNone (getL1HintAttr ()))
213- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
236+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
214237
215238 if (!isReadHintOrNone (getL2HintAttr ()))
216- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
239+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
217240
218241 if (!isReadHintOrNone (getL3HintAttr ()))
219- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
242+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
220243
221244 return success ();
222245}
@@ -238,13 +261,13 @@ LogicalResult LoadNdOp::verify() {
238261 return emitOpError (" Invalid result, it should be a VectorType.\n " );
239262
240263 if (!isReadHintOrNone (getL1HintAttr ()))
241- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
264+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
242265
243266 if (!isReadHintOrNone (getL2HintAttr ()))
244- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
267+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
245268
246269 if (!isReadHintOrNone (getL3HintAttr ()))
247- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
270+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
248271
249272 auto array_len = tdescTy.getArrayLength ();
250273 auto tdescShape = getShapeOf (tdescTy);
@@ -280,8 +303,9 @@ LogicalResult LoadNdOp::verify() {
280303 auto it = tdescShape.begin ();
281304 tdescShape.insert (it, array_len);
282305 }
306+ auto sgMap = tdescTy.getSGMapAttr ();
283307
284- if (tdescShape != valueShape)
308+ if (! isArgShapesValid ( tdescShape, valueShape, sgMap) )
285309 return emitOpError () << " Result shape doesn't match TensorDesc shape."
286310 << " The expected shape is " << makeString (tdescShape)
287311 << " . But the given shape is "
@@ -303,17 +327,26 @@ LogicalResult StoreNdOp::verify() {
303327 return emitOpError (" Expects a non-scattered TensorDesc.\n " );
304328
305329 if (!valTy)
306- return emitOpError (" Exepcting a VectorType result.\n " );
330+ return emitOpError (" Expecting a VectorType result.\n " );
307331
308332 if (!isWriteHintOrNone (getL1HintAttr ()))
309- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
333+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
310334
311335 if (!isWriteHintOrNone (getL2HintAttr ()))
312- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
336+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
313337
314338 if (!isWriteHintOrNone (getL3HintAttr ()))
315- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
339+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
340+
341+ auto tdescShape = getShapeOf (dstTy);
342+ auto valueShape = getShapeOf (valTy);
343+ auto sgMap = dstTy.getSGMapAttr ();
316344
345+ if (!isArgShapesValid (tdescShape, valueShape, sgMap))
346+ return emitOpError () << " Result shape doesn't match TensorDesc shape."
347+ << " The expected shape is " << makeString (tdescShape)
348+ << " . But the given shape is "
349+ << makeString (valueShape) << " .\n " ;
317350 return success ();
318351}
319352
@@ -423,13 +456,13 @@ LogicalResult PrefetchOp::verify() {
423456 return emitOpError (" Expects a scattered TensorDesc.\n " );
424457
425458 if (!isReadHintOrNone (getL1HintAttr ()))
426- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
459+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
427460
428461 if (!isReadHintOrNone (getL2HintAttr ()))
429- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
462+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
430463
431464 if (!isReadHintOrNone (getL3HintAttr ()))
432- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
465+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
433466
434467 return success ();
435468}
@@ -446,13 +479,13 @@ LogicalResult LoadGatherOp::verify() {
446479 return emitOpError (" Expects a scattered TensorDesc.\n " );
447480
448481 if (!isReadHintOrNone (getL1HintAttr ()))
449- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
482+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
450483
451484 if (!isReadHintOrNone (getL2HintAttr ()))
452- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
485+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
453486
454487 if (!isReadHintOrNone (getL3HintAttr ()))
455- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
488+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
456489
457490 auto tdescElemTy = tdescTy.getElementType ();
458491 auto valueElemTy = getElementType ();
@@ -490,13 +523,13 @@ LogicalResult StoreScatterOp::verify() {
490523 return emitOpError (" Expects a scattered TensorDesc.\n " );
491524
492525 if (!isWriteHintOrNone (getL1HintAttr ()))
493- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
526+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
494527
495528 if (!isWriteHintOrNone (getL2HintAttr ()))
496- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
529+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
497530
498531 if (!isWriteHintOrNone (getL3HintAttr ()))
499- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
532+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
500533
501534 auto maskTy = getMaskType ();
502535 auto valueTy = getValueType ();
0 commit comments