@@ -81,15 +81,18 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
8181 auto maskShape = getShapeOf (maskTy);
8282 auto valueShape = getShapeOf (valueTy);
8383 auto tdescShape = getShapeOf (tdescTy);
84- auto chunkSize = tdescTy.getChunkSize ();
84+ auto chunkSize = tdescTy.getChunkSizeAsInt ();
8585
8686 if (valueTy.getElementType () != tdescTy.getElementType ())
8787 return emitError ()
8888 << " Value should have the same element type as TensorDesc." ;
8989
90- if (tdescShape[0 ] != maskShape[0 ])
90+ llvm::SmallVector<int64_t > expectedMaskShape (tdescShape);
91+ if (chunkSize > 1 )
92+ expectedMaskShape.pop_back ();
93+ if (expectedMaskShape != maskShape)
9194 return emitError ()
92- << " dim-0 of the Mask and TensorDesc should be the same ." ;
95+ << " Mask should match TensorDesc except the chunk size dim ." ;
9396
9497 // a valid shape for SIMT case
9598 if (valueTy.getRank () == 1 && valueTy.getNumElements () == chunkSize) {
@@ -203,11 +206,9 @@ LogicalResult CreateNdDescOp::verify() {
203206 " is a memref) should match with each other." );
204207
205208 // check result TensorDesc rank
206- invalidRank = (getType ().getRank () > 2 || getType ().getRank () > rank);
207-
208- if (invalidRank)
209+ if (getType ().getRank () > rank)
209210 return emitOpError (
210- " Expecting the TensorDesc rank is up to 2 and not greater than the "
211+ " Expecting the TensorDesc rank is not greater than the "
211212 " ranks of shape, strides, offsets or the memref source." );
212213
213214 if (invalidElemTy)
@@ -247,12 +248,12 @@ LogicalResult LoadNdOp::verify() {
247248 auto tdescTy = getTensorDescType ();
248249 auto valueTy = getType ();
249250
250- if (tdescTy.getRank () > 2 )
251- return emitOpError (" Expecting a 1D/2D TensorDesc.\n " );
252-
253251 if (tdescTy.isScattered ())
254252 return emitOpError (" Expects a non-scattered TensorDesc.\n " );
255253
254+ if (tdescTy.getRank () > 2 )
255+ return emitOpError (" Expects a 1D or 2D TensorDesc.\n " );
256+
256257 if (!valueTy)
257258 return emitOpError (" Invalid result, it should be a VectorType.\n " );
258259
@@ -316,15 +317,13 @@ LogicalResult LoadNdOp::verify() {
316317 }
317318
318319 auto array_len = tdescTy.getArrayLength ();
319- if (array_len > 1 ) {
320+ if (array_len > 1 )
320321 tdescShape.insert (tdescShape.begin (), array_len);
321- }
322322
323- if (tdescShape != valueShape) {
323+ if (tdescShape != valueShape)
324324 return emitOpError () << " Result shape " << makeString (valueShape)
325325 << " is not consistent with tensor descriptor "
326326 << tdescTy;
327- }
328327
329328 return success ();
330329}
@@ -336,12 +335,12 @@ LogicalResult StoreNdOp::verify() {
336335 auto dstTy = getTensorDescType (); // Tile
337336 auto valTy = getValueType (); // Vector
338337
339- if (dstTy.getRank () > 2 )
340- return emitOpError (" Expecting a 1D/2D TensorDesc.\n " );
341-
342338 if (dstTy.isScattered ())
343339 return emitOpError (" Expects a non-scattered TensorDesc.\n " );
344340
341+ if (dstTy.getRank () > 2 )
342+ return emitOpError (" Expects a 1D or 2D TensorDesc.\n " );
343+
345344 if (!valTy)
346345 return emitOpError (" Expecting a VectorType result.\n " );
347346
@@ -370,22 +369,21 @@ LogicalResult StoreNdOp::verify() {
370369 return emitOpError ()
371370 << " TensorDesc doesn't need LayoutAttr for SIMT code" ;
372371
373- if (tdescElems % valueElems) {
372+ if (tdescElems % valueElems)
374373 return emitOpError ()
375374 << " Value shape " << makeString (getShapeOf (valTy))
376375 << " is not a valid distribution for tensor descriptor " << dstTy;
377- }
376+
378377 return success ();
379378 }
380379
381380 // SIMD code should have the same shape as the tensor descriptor.
382381 auto tdescShape = getShapeOf (dstTy);
383382 auto valueShape = getShapeOf (valTy);
384- if (tdescShape != valueShape) {
383+ if (tdescShape != valueShape)
385384 return emitOpError () << " Value shape " << makeString (valueShape)
386385 << " is not consistent with tensor descriptor "
387386 << dstTy;
388- }
389387
390388 return success ();
391389}
@@ -449,25 +447,8 @@ LogicalResult CreateDescOp::verify() {
449447 << " , TensorDesc: " << tdescMemorySpace;
450448
451449 // check total size
452- auto chunkSize = tdescTy.getChunkSize ();
453- auto elemBits = tdescTy.getElementType ().getIntOrFloatBitWidth ();
454- auto bitsPerLane = elemBits * chunkSize;
455- if (chunkSize > 1 && bitsPerLane % 32 ) {
456- // For 8-bit and 16-bit data, the hardware only supports chunk size of 1.
457- // For 32-bit data, the hardware can support larger larger chunk size. So
458- // we can bitcast 8-bit/16-bit data to 32-bit data for better performance.
459- // But this requires the total size is 32 bit aligned to make the
460- // optimization work.
461- return emitOpError (
462- " access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned." );
463- }
464-
465- auto lscConstraints = 512 * 8 ; // each access is upto 512 bytes.
466- if (elemBits * tdescTy.getNumElements () > lscConstraints)
467- return emitOpError (" total access size (simd_lanes * chunk_size * "
468- " sizeof(elemTy)) is upto 512 bytes." );
469-
470- SmallVector<int64_t > shape ({(int64_t )getNumOffsets ()});
450+ auto chunkSize = tdescTy.getChunkSizeAsInt ();
451+ SmallVector<int64_t > shape (getOffsetsType ().getShape ());
471452 if (chunkSize != 1 )
472453 shape.push_back (chunkSize);
473454
@@ -563,6 +544,23 @@ void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
563544 build (builder, state, tensorDesc, ofrs);
564545}
565546
547+ LogicalResult UpdateOffsetOp::verify () {
548+ auto tdescTy = getTensorDescType ();
549+ if (!tdescTy.isScattered ())
550+ return emitOpError (" Expects a scattered TensorDesc.\n " );
551+
552+ SmallVector<int64_t > expectedOffsetShape = getShapeOf (tdescTy);
553+ SmallVector<int64_t > offsetShape = getShapeOf (getOffsetsType ());
554+ if (tdescTy.getChunkSizeAsInt () > 1 )
555+ expectedOffsetShape.pop_back ();
556+
557+ if (expectedOffsetShape != offsetShape)
558+ return emitOpError (
559+ " Offsets should match TensorDesc except the chunk size dim." );
560+
561+ return success ();
562+ }
563+
566564// ===----------------------------------------------------------------------===//
567565// XeGPU_DpasOp
568566// ===----------------------------------------------------------------------===//
0 commit comments