@@ -1245,73 +1245,88 @@ LogicalResult GatherOp::inferReturnTypes(
12451245}
12461246
12471247// -- DescriptorGatherOp
1248- LogicalResult DescriptorGatherOp::verifyResultType (Operation *op,
1249- mlir::ShapedType type) {
1250- if (type.getRank () != 2 )
1251- return op->emitOpError (" result must be a 2D tensor, but got " ) << type;
1248+ LogicalResult
1249+ DescriptorGatherOp::verifyResultType (Operation *op, ShapedType resultType,
1250+ RankedTensorType indicesType) {
1251+ if (indicesType.getRank () != 1 )
1252+ return op->emitOpError (" x offsets must be a 1D tensor, but got " )
1253+ << indicesType;
1254+ if (resultType.getRank () != 2 )
1255+ return op->emitOpError (" result must be a 2D tensor, but got " )
1256+ << resultType;
12521257
12531258 // The swizzling of TMA accesses matches that of the MMAv3 shared memory
12541259 // layouts. However, these have minimum size requirements.
12551260 // TODO: We can support smaller gather sizes by padding the `local_alloc` this
12561261 // lowers to to the nearest minimum tile size.
1257- if (unsigned rows = type .getShape ()[0 ]; rows < 8 ) {
1262+ if (unsigned rows = resultType .getShape ()[0 ]; rows < 8 ) {
12581263 return op->emitOpError (" gather must have at least 8 rows, but got " )
12591264 << rows;
12601265 }
12611266
1262- Type dtype = type .getElementType ();
1267+ Type dtype = resultType .getElementType ();
12631268 if (dtype.getIntOrFloatBitWidth () > 32 )
12641269 return op->emitOpError (" TMA dtype cannot be greater than 32 bits" );
12651270
12661271 unsigned minCols = 32 / dtype.getIntOrFloatBitWidth () * 8 ;
1267- if (unsigned cols = type .getShape ()[1 ]; cols < minCols) {
1272+ if (unsigned cols = resultType .getShape ()[1 ]; cols < minCols) {
12681273 return op->emitOpError (" gather of " )
12691274 << dtype << " must have at least " << minCols << " columns, but got "
12701275 << cols;
12711276 }
12721277
1278+ if (resultType.getShape ()[0 ] != indicesType.getShape ()[0 ]) {
1279+ return op->emitOpError (" result tensor must have as many rows as indices (" )
1280+ << indicesType.getShape ()[0 ] << " ), but got " << resultType;
1281+ }
1282+
12731283 return success ();
12741284}
12751285
1276- LogicalResult DescriptorGatherOp::verify () {
1277- RankedTensorType blockType = getDesc ().getType ().getBlockType ();
1286+ static LogicalResult verifyGatherScatterOp (Operation *op,
1287+ RankedTensorType blockType,
1288+ RankedTensorType resultType,
1289+ RankedTensorType indicesType) {
12781290 // Gather from `!tt.tensordesc<tensor<1xMxdtype>>`.
1279- if (blockType.getRank () != 2 )
1280- return emitOpError (" block must be a 2D tensor, but got " ) << blockType;
1281- if (blockType.getShape ()[0 ] != 1 )
1282- return emitOpError (" block must have exactly 1 row, but got " ) << blockType;
1283-
1284- // With x offsets `tensor<Nxinttype>`.
1285- RankedTensorType indicesType = getXOffsets ().getType ();
1286- if (indicesType.getRank () != 1 )
1287- return emitOpError (" x offsets must be a 1D tensor, but got " )
1288- << indicesType;
1291+ if (blockType.getRank () != 2 ) {
1292+ return op->emitOpError (" block must be a 2D tensor, but got " ) << blockType;
1293+ }
1294+ if (blockType.getShape ()[0 ] != 1 ) {
1295+ return op->emitOpError (" block must have exactly 1 row, but got " )
1296+ << blockType;
1297+ }
12891298
1290- // Into `tensor<NxMxdtype>`.
1291- RankedTensorType resultType = getType ();
1292- if (failed (verifyResultType (*this , resultType)))
1299+ // With x offsets `tensor<Nxinttype>` into `tensor<NxMxdtype>`.
1300+ if (failed (DescriptorGatherOp::verifyResultType (op, resultType, indicesType)))
12931301 return failure ();
12941302
1295- if (resultType.getShape ()[0 ] != indicesType.getShape ()[0 ]) {
1296- return emitOpError (" result tensor must have as many rows as indices (" )
1297- << indicesType.getShape ()[0 ] << " ), but got " << resultType;
1298- }
12991303 if (resultType.getShape ()[1 ] != blockType.getShape ()[1 ]) {
1300- return emitOpError (" result tensor number of columns must match block (" )
1304+ return op-> emitOpError (" result tensor number of columns must match block (" )
13011305 << blockType.getShape ()[1 ] << " ), but got " << resultType;
13021306 }
13031307 if (resultType.getElementType () != blockType.getElementType ()) {
1304- return emitOpError (" result tensor element type must match block (" )
1308+ return op-> emitOpError (" result tensor element type must match block (" )
13051309 << blockType.getElementType () << " ), but got " << resultType;
13061310 }
13071311
13081312 return success ();
13091313}
13101314
1315+ LogicalResult DescriptorGatherOp::verify () {
1316+ return verifyGatherScatterOp (*this , getDesc ().getType ().getBlockType (),
1317+ getResult ().getType (), getXOffsets ().getType ());
1318+ }
1319+
1320+ // -- DescriptorScatterOp --
1321+ LogicalResult DescriptorScatterOp::verify () {
1322+ return verifyGatherScatterOp (*this , getDesc ().getType ().getBlockType (),
1323+ getSrc ().getType (), getXOffsets ().getType ());
1324+ }
1325+
13111326// -- DescriptorLoadOp --
1312- static LogicalResult verifyDesciptorLoadStoreType (Operation *op,
1313- TensorDescType desc,
1314- RankedTensorType tensor) {
1327+ static LogicalResult verifyDescriptorLoadStoreType (Operation *op,
1328+ TensorDescType desc,
1329+ RankedTensorType tensor) {
13151330 RankedTensorType block = desc.getBlockType ();
13161331 ArrayRef<int64_t > blockShape = block.getShape ();
13171332 ArrayRef<int64_t > tensorShape = tensor.getShape ();
@@ -1328,17 +1343,17 @@ static LogicalResult verifyDesciptorLoadStoreType(Operation *op,
13281343 if (blockShape == tensorShape &&
13291344 block.getElementType () == tensor.getElementType ())
13301345 return success ();
1331- return op->emitOpError (" tensor desciptor block and tensor types must match" );
1346+ return op->emitOpError (" tensor descriptor block and tensor types must match" );
13321347}
13331348
13341349LogicalResult DescriptorLoadOp::verify () {
1335- return verifyDesciptorLoadStoreType (*this , getDesc ().getType (), getType ());
1350+ return verifyDescriptorLoadStoreType (*this , getDesc ().getType (), getType ());
13361351}
13371352
13381353// -- DescriptorStoreOp --
13391354LogicalResult DescriptorStoreOp::verify () {
1340- return verifyDesciptorLoadStoreType (*this , getDesc ().getType (),
1341- getSrc ().getType ());
1355+ return verifyDescriptorLoadStoreType (*this , getDesc ().getType (),
1356+ getSrc ().getType ());
13421357}
13431358
13441359// -- ExperimentalTensormapCreateOp --
0 commit comments