Skip to content

Commit b0f626b

Browse files
committed
cleanup
1 parent a2bc905 commit b0f626b

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
265265
}
266266

267267
LogicalResult CreateNdDescOp::verify() {
268-
int64_t rank = getMixedSizes().size();
269-
bool invalidRank = false;
268+
size_t rank = getMixedSizes().size();
269+
bool invalidRank = rank != getMixedStrides().size();
270270
bool invalidElemTy = false;
271271

272272
// Memory space of created TensorDesc should match with the source.
@@ -280,16 +280,13 @@ LogicalResult CreateNdDescOp::verify() {
280280
<< " Source: " << srcMemorySpace
281281
<< ", TensorDesc: " << tdescMemorySpace;
282282

283-
if (int64_t offsetRank = getMixedOffsets().size())
283+
if (size_t offsetRank = getMixedOffsets().size())
284284
invalidRank |= (offsetRank != rank);
285285

286286
// check source type matches the rank if it is a memref.
287287
// It also should have the same ElementType as TensorDesc.
288-
auto memrefTy = dyn_cast<MemRefType>(getSourceType());
289-
if (memrefTy) {
290-
invalidRank |= (memrefTy.getRank() != rank);
288+
if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
291289
invalidElemTy |= memrefTy.getElementType() != getElementType();
292-
}
293290

294291
if (llvm::isa<IntegerType>(getSourceType())) {
295292
// strides and shape must present for integer source.
@@ -298,16 +295,13 @@ LogicalResult CreateNdDescOp::verify() {
298295
"integer source.");
299296
}
300297

301-
// mismatches among shape, strides, and offsets are
302-
// already handeled by OffsetSizeAndStrideOpInterface.
303-
// So they are not check here.
304298
if (invalidRank)
305299
return emitOpError(
306300
"Expecting the rank of shape, strides, offsets, and source (if source "
307301
"is a memref) should match with each other.");
308302

309303
// check result TensorDesc rank
310-
if (getType().getRank() > rank)
304+
if (getType().getRank() > (int64_t)rank)
311305
return emitOpError(
312306
"Expecting the TensorDesc rank is not greater than the "
313307
"ranks of shape, strides, offsets or the memref source.");

0 commit comments

Comments
 (0)