@@ -101,6 +101,48 @@ static bool isEvenDistributed(llvm::ArrayRef<int64_t> shape,
101101 return true ;
102102}
103103
104+ static LogicalResult isValidGatherScatterParams (Type maskTy, VectorType valueTy, TensorDescType tdescTy, UnitAttr transposeAttr, function_ref<InFlightDiagnostic()> emitError) {
105+
106+ if (!tdescTy.isScattered ())
107+ return emitError () << " Expects a scattered TensorDesc." ;
108+
109+ if (!valueTy)
110+ return emitError () << " Expecting a vector type result." ;
111+
112+ auto maskShape = getShapeOf (maskTy);
113+ auto valueShape = getShapeOf (valueTy);
114+ auto tdescShape = getShapeOf (tdescTy);
115+ auto chunkSize = tdescTy.getChunkSize ();
116+
117+ if (valueTy.getElementType () != tdescTy.getElementType ())
118+ return emitError () << " Value should have the same element type as TensorDesc." ;
119+
120+ if (tdescShape[0 ] != maskShape[0 ])
121+ return emitError () << " dim-0 of the Mask and TensorDesc should be the same." ;
122+
123+ // a valid shape for SIMT case
124+ if (valueTy.getRank () == 1 && valueTy.getNumElements () == chunkSize) {
125+ if (tdescTy.getLayoutAttr ())
126+ return emitError () << " TensorDesc doesn't need LayoutAttr for SIMT code" ;
127+ if (transposeAttr)
128+ return emitError () << " doesn't need TransposeAttr for SIMT code" ;
129+ return success ();
130+ }
131+
132+ if (tdescTy.getRank () == 2 && valueTy.getRank () == 2 ) {
133+ if (!transposeAttr)
134+ return emitError () << " rank-2 tensor has to be transposed." ;
135+ transpose ({1 , 0 }, tdescShape);
136+ }
137+
138+ if (tdescShape != valueShape)
139+ return emitError () << " Value shape " << makeString (valueShape)
140+ << " is neither a valid distribution for SIMT nor "
141+ " consistent with the tensor descriptor for SIMD "
142+ << tdescTy;
143+ return success ();
144+ }
145+
104146// ===----------------------------------------------------------------------===//
105147// XeGPU_CreateNdDescOp
106148// ===----------------------------------------------------------------------===//
@@ -517,12 +559,6 @@ LogicalResult LoadGatherOp::verify() {
517559 auto maskTy = getMaskType ();
518560 auto valueTy = getValueType ();
519561
520- if (!valueTy)
521- return emitOpError (" Expecting a vector type result.\n " );
522-
523- if (!tdescTy.isScattered ())
524- return emitOpError (" Expects a scattered TensorDesc.\n " );
525-
526562 if (!isReadHintOrNone (getL1HintAttr ()))
527563 return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
528564
@@ -532,52 +568,17 @@ LogicalResult LoadGatherOp::verify() {
532568 if (!isReadHintOrNone (getL3HintAttr ()))
533569 return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
534570
535- auto tdescElemTy = tdescTy.getElementType ();
536- auto valueElemTy = getElementType ();
537- if (tdescElemTy != valueElemTy)
538- return emitOpError (
539- " Value should have the same element type as TensorDesc." );
540-
541- auto maskShape = getShapeOf (maskTy);
542- auto valueShape = getShapeOf (valueTy);
543- auto tdescShape = getShapeOf (tdescTy);
544-
545- if (tdescShape[0 ] != maskShape[0 ])
546- return emitOpError (" dim-0 of the Mask and TensorDesc should be the same." );
547-
548- auto chunkSize = tdescTy.getChunkSize ();
549-
550- // a valid shape for SIMT case
551- if (valueTy.getRank () == 1 && valueTy.getNumElements () == chunkSize) {
552- if (tdescTy.getLayoutAttr ())
553- return emitOpError ()
554- << " TensorDesc doesn't need LayoutAttr for SIMT code" ;
555- if (getTransposeAttr ())
556- return emitOpError () << " doesn't need TransposeAttr for SIMT code" ;
557- return success ();
558- }
559-
560- if (tdescTy.getRank () == 2 && valueTy.getRank () == 2 ) {
561- if (!getTransposeAttr ())
562- return emitOpError (" load of rank-2 tensor has to be transposed." );
563- transpose ({1 , 0 }, tdescShape);
564- }
565-
566- if (tdescShape != valueShape)
567- return emitOpError () << " Result shape " << makeString (valueShape)
568- << " is neither a valid distribution for SIMT nor "
569- " consistent with the tensor descriptor for SIMD "
570- << tdescTy;
571- return success ();
571+ return isValidGatherScatterParams (maskTy, valueTy, tdescTy, getTransposeAttr (),
572+ [&]() { return emitOpError (); });
572573}
573574
574575// ===----------------------------------------------------------------------===//
575576// XeGPU_StoreScatterOp
576577// ===----------------------------------------------------------------------===//
577578LogicalResult StoreScatterOp::verify () {
578579 auto tdescTy = getTensorDescType ();
579- if (!tdescTy. isScattered ())
580- return emitOpError ( " Expects a scattered TensorDesc. \n " );
580+ auto maskTy = getMaskType ();
581+ auto valueTy = getValueType ( );
581582
582583 if (!isWriteHintOrNone (getL1HintAttr ()))
583584 return emitOpError (" invalid l1_hint: " ) << getL1HintAttr ();
@@ -588,43 +589,8 @@ LogicalResult StoreScatterOp::verify() {
588589 if (!isWriteHintOrNone (getL3HintAttr ()))
589590 return emitOpError (" invalid l3_hint: " ) << getL3HintAttr ();
590591
591- auto maskTy = getMaskType ();
592- auto valueTy = getValueType ();
593-
594- if (!valueTy)
595- return emitOpError (" Expecting a vector type for the value.\n " );
596-
597- auto maskShape = getShapeOf (maskTy);
598- auto tdescShape = getShapeOf (tdescTy);
599- auto valueShape = getShapeOf (valueTy);
600- if (tdescShape[0 ] != maskShape[0 ])
601- return emitOpError (" dim-0 of the Mask and TensorDesc should be the same." );
602-
603- auto chunkSize = tdescTy.getChunkSize ();
604-
605- // a valid shape for SIMT case
606- if (valueTy.getRank () == 1 && valueTy.getNumElements () == chunkSize) {
607- if (tdescTy.getLayoutAttr ())
608- return emitOpError ()
609- << " TensorDesc doesn't need LayoutAttr for SIMT code" ;
610- if (getTransposeAttr ())
611- return emitOpError () << " doesn't need TransposeAttr for SIMT code" ;
612- return success ();
613- }
614-
615- if (tdescTy.getRank () == 2 && valueTy.getRank () == 2 ) {
616- if (!getTransposeAttr ())
617- return emitOpError (" Store of a rank-2 tensor has to be transposed." );
618- transpose ({1 , 0 }, tdescShape);
619- }
620-
621- if (tdescShape != valueShape)
622- return emitOpError () << " Value shape " << makeString (valueShape)
623- << " is neither a valid distribution for SIMT nor "
624- " consistent with the tensor descriptor for SIMD "
625- << tdescTy;
626-
627- return success ();
592+ return isValidGatherScatterParams (maskTy, valueTy, tdescTy, getTransposeAttr (),
593+ [&]() { return emitOpError (); });
628594}
629595
630596// ===----------------------------------------------------------------------===//
@@ -660,14 +626,18 @@ LogicalResult DpasOp::verify() {
660626 auto rhsShape = getRhsType ().getShape ();
661627 auto resShape = getResultType ().getShape ();
662628
663- if (getAcc ()) {
664- if (getAcc ().getType () != getResultType ())
665- return emitOpError (" Expecting the acc type to be the same as result." );
666- }
629+ if (getAcc () && getAcc ().getType () != getResultType ())
630+ return emitOpError (" Expecting the acc type to be the same as result." );
667631
668- // SIMT code: skip the check since lack of semantic info at this level.
632+ // SIMT code: the size of the B operand has to be a multiple of 32 bits.
633+ // It skips the semantic check since lack of architecture information.
669634 // Users need to ensure the correctness.
670635 if (lhsRank == 1 && rhsRank == 1 && resRank == 1 ) {
636+ auto numElems = getRhsType ().getNumElements ();
637+ auto elemTy = getRhsType ().getElementType ();
638+ auto factor = 32 / elemTy.getIntOrFloatBitWidth ();
639+ if (numElems % factor != 0 )
640+ return emitOpError (" Expecting B operand to be a multiple of 32 bits." );
671641 return success ();
672642 } else { // SIMD code
673643 if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3 ) || resRank != 2 )
0 commit comments