@@ -73,6 +73,39 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
7373 kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
7474}
7575
76+ // Validations for nd instruction arguments is successful if any of these are
77+ // true:
78+ // - tensor descriptor and the output vector shapes exactly match.
79+ // - tensor descriptor has a sg_map attribute and the distributed vector shape
80+ // matches the tensor descriptor shape when scaled using sg_map factors on
81+ // each dimension.
82+ static bool isArgShapesValid (ArrayRef<int64_t > descShape,
83+ ArrayRef<int64_t > valShape, SGMapAttr sgMap) {
84+ if (descShape == valShape) {
85+ if (!sgMap)
86+ return true ;
87+
88+ // this can be relaxed if necessary by supporting non-2d shapes distribution
89+ // until the constraints are defined this lives here instead of the tensor
90+ // descriptor type.
91+ return valShape.size () == sgMap.getWiLayout ().size ();
92+ }
93+
94+ if (!sgMap)
95+ return false ;
96+
97+ if (valShape.size () != descShape.size ())
98+ return false ;
99+
100+ for (const auto &[factor, dim, expected] :
101+ llvm::zip_equal (sgMap.getWiLayout (), valShape, descShape)) {
102+ if (factor * dim != expected)
103+ return false ;
104+ }
105+
106+ return true ;
107+ }
108+
76109// ===----------------------------------------------------------------------===//
77110// XeGPU_CreateNdDescOp
78111// ===----------------------------------------------------------------------===//
@@ -210,13 +243,13 @@ LogicalResult PrefetchNdOp::verify() {
210243 return emitOpError (" Expects a non-scattered TensorDesc.\n " );
211244
212245 if (!isReadHintOrNone (getL1HintAttr ()))
213- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
246+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
214247
215248 if (!isReadHintOrNone (getL2HintAttr ()))
216- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
249+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
217250
218251 if (!isReadHintOrNone (getL3HintAttr ()))
219- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
252+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
220253
221254 return success ();
222255}
@@ -238,13 +271,13 @@ LogicalResult LoadNdOp::verify() {
238271 return emitOpError (" Invalid result, it should be a VectorType.\n " );
239272
240273 if (!isReadHintOrNone (getL1HintAttr ()))
241- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
274+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
242275
243276 if (!isReadHintOrNone (getL2HintAttr ()))
244- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
277+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
245278
246279 if (!isReadHintOrNone (getL3HintAttr ()))
247- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
280+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
248281
249282 auto array_len = tdescTy.getArrayLength ();
250283 auto tdescShape = getShapeOf (tdescTy);
@@ -280,8 +313,9 @@ LogicalResult LoadNdOp::verify() {
280313 auto it = tdescShape.begin ();
281314 tdescShape.insert (it, array_len);
282315 }
316+ auto sgMap = tdescTy.getSGMapAttr ();
283317
284- if (tdescShape != valueShape)
318+ if (! isArgShapesValid ( tdescShape, valueShape, sgMap) )
285319 return emitOpError () << " Result shape doesn't match TensorDesc shape."
286320 << " The expected shape is " << makeString (tdescShape)
287321 << " . But the given shape is "
@@ -303,17 +337,26 @@ LogicalResult StoreNdOp::verify() {
303337 return emitOpError (" Expects a non-scattered TensorDesc.\n " );
304338
305339 if (!valTy)
306- return emitOpError (" Exepcting a VectorType result.\n " );
340+ return emitOpError (" Expecting a VectorType result.\n " );
307341
308342 if (!isWriteHintOrNone (getL1HintAttr ()))
309- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
343+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
310344
311345 if (!isWriteHintOrNone (getL2HintAttr ()))
312- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
346+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
313347
314348 if (!isWriteHintOrNone (getL3HintAttr ()))
315- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
349+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
350+
351+ auto tdescShape = getShapeOf (dstTy);
352+ auto valueShape = getShapeOf (valTy);
353+ auto sgMap = dstTy.getSGMapAttr ();
316354
355+ if (!isArgShapesValid (tdescShape, valueShape, sgMap))
356+ return emitOpError () << " Result shape doesn't match TensorDesc shape."
357+ << " The expected shape is " << makeString (tdescShape)
358+ << " . But the given shape is "
359+ << makeString (valueShape) << " .\n " ;
317360 return success ();
318361}
319362
@@ -423,13 +466,13 @@ LogicalResult PrefetchOp::verify() {
423466 return emitOpError (" Expects a scattered TensorDesc.\n " );
424467
425468 if (!isReadHintOrNone (getL1HintAttr ()))
426- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
469+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
427470
428471 if (!isReadHintOrNone (getL2HintAttr ()))
429- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
472+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
430473
431474 if (!isReadHintOrNone (getL3HintAttr ()))
432- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
475+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
433476
434477 return success ();
435478}
@@ -446,13 +489,13 @@ LogicalResult LoadGatherOp::verify() {
446489 return emitOpError (" Expects a scattered TensorDesc.\n " );
447490
448491 if (!isReadHintOrNone (getL1HintAttr ()))
449- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
492+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
450493
451494 if (!isReadHintOrNone (getL2HintAttr ()))
452- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
495+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
453496
454497 if (!isReadHintOrNone (getL3HintAttr ()))
455- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
498+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
456499
457500 auto tdescElemTy = tdescTy.getElementType ();
458501 auto valueElemTy = getElementType ();
@@ -490,13 +533,13 @@ LogicalResult StoreScatterOp::verify() {
490533 return emitOpError (" Expects a scattered TensorDesc.\n " );
491534
492535 if (!isWriteHintOrNone (getL1HintAttr ()))
493- return emitOpError (" invlid l1_hint: " ) << getL1HintAttr ();
536+ return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
494537
495538 if (!isWriteHintOrNone (getL2HintAttr ()))
496- return emitOpError (" invlid l2_hint: " ) << getL2HintAttr ();
539+ return emitOpError (" invalid l2_hint: " ) << getL2HintAttr ();
497540
498541 if (!isWriteHintOrNone (getL3HintAttr ()))
499- return emitOpError (" invlid l3_hint: " ) << getL3HintAttr ();
542+ return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
500543
501544 auto maskTy = getMaskType ();
502545 auto valueTy = getValueType ();
0 commit comments