Skip to content

Commit 10a6aff

Browse files
committed
Remove redundant braces.
1 parent 953b850 commit 10a6aff

File tree

1 file changed

+20
-34
lines changed

1 file changed

+20
-34
lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -210,13 +210,13 @@ class CreateNdDescToXeVMPattern
210210
// Get shape values from op fold results.
211211
baseShapeW = createOffset(mixedSizes, rank - 1);
212212
baseShapeH = createOffset(mixedSizes, rank - 2);
213-
if (sourceMemrefTy) {
213+
if (sourceMemrefTy)
214214
// Cast index to i64.
215215
baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
216-
} else if (baseAddr.getType() != i64Ty) {
216+
else if (baseAddr.getType() != i64Ty)
217217
// Pointer type may be i32. Cast to i64 if needed.
218218
baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
219-
}
219+
220220
// Populate payload.
221221
Value payLoadAsI64 =
222222
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
@@ -288,9 +288,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
288288

289289
auto tdesc = adaptor.getTensorDesc();
290290
auto tdescTy = op.getTensorDescType();
291-
if (tdescTy.getRank() != 2) {
291+
if (tdescTy.getRank() != 2)
292292
return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
293-
}
294293

295294
VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
296295
Value payLoadAsI64 =
@@ -308,10 +307,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
308307
Value offsetH;
309308
auto mixedOffsets = op.getMixedOffsets();
310309
int64_t opOffsetsSize = mixedOffsets.size();
311-
if (opOffsetsSize != 0 && opOffsetsSize != 2) {
310+
if (opOffsetsSize != 0 && opOffsetsSize != 2)
312311
return rewriter.notifyMatchFailure(op,
313312
"Expected 2D offsets or no offsets.");
314-
}
315313
if (opOffsetsSize) {
316314
// If mixed offsets are provided by the op convert them to i32.
317315
offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
@@ -348,10 +346,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
348346
int32_t vblocks = tdescTy.getArrayLength();
349347
if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
350348
VectorType srcVecTy = dyn_cast<VectorType>(adaptor.getValue().getType());
351-
if (!srcVecTy) {
349+
if (!srcVecTy)
352350
return rewriter.notifyMatchFailure(
353351
op, "Expected store value to be a vector type.");
354-
}
355352
auto storeCacheControl =
356353
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
357354
Value src = adaptor.getValue();
@@ -419,10 +416,9 @@ class CreateDescToXeVMPattern
419416
ConversionPatternRewriter &rewriter) const override {
420417
auto eTy = op.getTensorDescType().getElementType();
421418
auto eBw = eTy.getIntOrFloatBitWidth();
422-
if (eBw % 8 != 0) {
419+
if (eBw % 8 != 0)
423420
return rewriter.notifyMatchFailure(
424421
op, "Expected element type bit width to be multiple of 8.");
425-
}
426422
auto loc = op.getLoc();
427423
// Offsets are provided as scalar i64 by type converter.
428424
auto offsets = adaptor.getOffsets();
@@ -447,10 +443,9 @@ class UpdateOffsetToXeVMPattern
447443
ConversionPatternRewriter &rewriter) const override {
448444
auto eTy = op.getTensorDescType().getElementType();
449445
auto eBw = eTy.getIntOrFloatBitWidth();
450-
if (eBw % 8 != 0) {
446+
if (eBw % 8 != 0)
451447
return rewriter.notifyMatchFailure(
452448
op, "Expected element type bit width to be multiple of 8.");
453-
}
454449
auto loc = op.getLoc();
455450
// Scatter descriptor is provided as scalar i64 by type converter.
456451
// Offsets are provided as scalar i64 by type converter.
@@ -475,30 +470,27 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
475470
Value basePtrI64;
476471
// Load result or Store valye Type can be vector or scalar.
477472
Type valOrResTy;
478-
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
473+
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
479474
valOrResTy = op.getResult().getType();
480-
} else {
475+
else
481476
valOrResTy = adaptor.getValue().getType();
482-
}
483477
VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
484478
bool hasScalarVal = !valOrResVecTy;
485479
int64_t elemBitWidth =
486480
hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
487481
: valOrResVecTy.getElementType().getIntOrFloatBitWidth();
488482
// Element type must be multiple of 8 bits.
489-
if (elemBitWidth % 8 != 0) {
483+
if (elemBitWidth % 8 != 0)
490484
return rewriter.notifyMatchFailure(
491485
op, "Expected element type bit width to be multiple of 8.");
492-
}
493486
int64_t elemByteSize = elemBitWidth / 8;
494487
// Default memory space is global.
495488
LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
496489
ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
497490
// If tensor descriptor is available, we use its memory space.
498-
if (tdescTy) {
491+
if (tdescTy)
499492
ptrTypeLLVM = LLVM::LLVMPointerType::get(
500493
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
501-
}
502494
// Base pointer can come from source (load) or dest (store).
503495
// If they are memrefs, we use their memory space.
504496
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
@@ -524,32 +516,30 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
524516
Value offsets = adaptor.getOffsets();
525517
Value mask = adaptor.getMask();
526518
if (offsets) {
527-
if (dyn_cast<VectorType>(offsets.getType())) {
519+
if (dyn_cast<VectorType>(offsets.getType()))
528520
// Offset needs be scalar. Single element vector is converted to scalar
529521
// by type converter.
530522
return rewriter.notifyMatchFailure(op,
531523
"Expected offsets to be a scalar.");
532-
} else {
524+
else
533525
// If offsets are provided, we add them to the base pointer.
534526
// Offsets are in number of elements, we need to multiply by
535527
// element byte size.
536528
basePtrI64 =
537529
addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
538-
}
539530
}
540531
// Convert base pointer (i64) to LLVM pointer type.
541532
Value basePtrLLVM =
542533
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
543534

544535
Value maskForLane;
545536
VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
546-
if (maskVecTy) {
537+
if (maskVecTy)
547538
// Mask needs be scalar. Single element vector is converted to scalar by
548539
// type converter.
549540
return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
550-
} else {
541+
else
551542
maskForLane = mask;
552-
}
553543
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
554544
scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
555545
maskForLane, true, true);
@@ -609,10 +599,9 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
609599
auto tdescTy = op.getTensorDescType();
610600
Value basePtrI64 = adaptor.getSource();
611601
// Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
612-
if (basePtrI64.getType() != rewriter.getI64Type()) {
602+
if (basePtrI64.getType() != rewriter.getI64Type())
613603
basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
614604
basePtrI64);
615-
}
616605
Value offsets = adaptor.getOffsets();
617606
if (offsets) {
618607
VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
@@ -637,10 +626,9 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
637626
elemByteSize = *op.getOffsetAlignByte();
638627
}
639628
if (elemBitWidth != 0) {
640-
if (elemBitWidth % 8 != 0) {
629+
if (elemBitWidth % 8 != 0)
641630
return rewriter.notifyMatchFailure(
642631
op, "Expected element type bit width to be multiple of 8.");
643-
}
644632
elemByteSize = elemBitWidth / 8;
645633
}
646634
basePtrI64 =
@@ -651,10 +639,9 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
651639
LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
652640
ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
653641
// If tensor descriptor is available, we use its memory space.
654-
if (tdescTy) {
642+
if (tdescTy)
655643
ptrTypeLLVM = LLVM::LLVMPointerType::get(
656644
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
657-
}
658645
// If source is a memref, we use its memory space.
659646
if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
660647
auto addrSpace = memRefTy.getMemorySpaceAsInt();
@@ -883,9 +870,8 @@ struct ConvertXeGPUToXeVMPass
883870
return VectorType::get(sum, elemType);
884871
});
885872
typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
886-
if (type.isScattered()) {
873+
if (type.isScattered())
887874
return IntegerType::get(&getContext(), 64);
888-
}
889875
auto i32Type = IntegerType::get(&getContext(), 32);
890876
return VectorType::get(8, i32Type);
891877
});

0 commit comments

Comments
 (0)