@@ -210,13 +210,13 @@ class CreateNdDescToXeVMPattern
     // Get shape values from op fold results.
     baseShapeW = createOffset(mixedSizes, rank - 1);
     baseShapeH = createOffset(mixedSizes, rank - 2);
-    if (sourceMemrefTy) {
+    if (sourceMemrefTy)
       // Cast index to i64.
       baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
-    } else if (baseAddr.getType() != i64Ty) {
+    else if (baseAddr.getType() != i64Ty)
       // Pointer type may be i32. Cast to i64 if needed.
       baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
-    }
+
     // Populate payload.
     Value payLoadAsI64 =
         vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
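[Editor's note] The trailing lines of this hunk start the payload update: the 8 x i32 block-descriptor payload is viewed as 4 x i64 so the 64-bit base address fits in a single lane. A minimal sketch of the full idiom in the same builder style, assuming the file's surrounding includes, that lane 0 holds the base pointer, and that payloadTy names the original 8 x i32 vector type (all assumptions; the diff does not show them):

  // Sketch only: widen the payload view, store the base address, and
  // return to the i32 view expected by the rest of the lowering.
  Value payLoadAsI64 =
      vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
  payLoadAsI64 =
      vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64, 0);
  payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);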
@@ -288,9 +288,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
 
     auto tdesc = adaptor.getTensorDesc();
     auto tdescTy = op.getTensorDescType();
-    if (tdescTy.getRank() != 2) {
+    if (tdescTy.getRank() != 2)
       return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
-    }
 
     VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
     Value payLoadAsI64 =
@@ -308,10 +307,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     Value offsetH;
     auto mixedOffsets = op.getMixedOffsets();
     int64_t opOffsetsSize = mixedOffsets.size();
-    if (opOffsetsSize != 0 && opOffsetsSize != 2) {
+    if (opOffsetsSize != 0 && opOffsetsSize != 2)
       return rewriter.notifyMatchFailure(op,
                                          "Expected 2D offsets or no offsets.");
-    }
     if (opOffsetsSize) {
       // If mixed offsets are provided by the op, convert them to i32.
       offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
@@ -348,10 +346,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     int32_t vblocks = tdescTy.getArrayLength();
     if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
       VectorType srcVecTy = dyn_cast<VectorType>(adaptor.getValue().getType());
-      if (!srcVecTy) {
+      if (!srcVecTy)
         return rewriter.notifyMatchFailure(
             op, "Expected store value to be a vector type.");
-      }
       auto storeCacheControl =
           translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
       Value src = adaptor.getValue();
@@ -419,10 +416,9 @@ class CreateDescToXeVMPattern
                   ConversionPatternRewriter &rewriter) const override {
     auto eTy = op.getTensorDescType().getElementType();
     auto eBw = eTy.getIntOrFloatBitWidth();
-    if (eBw % 8 != 0) {
+    if (eBw % 8 != 0)
       return rewriter.notifyMatchFailure(
           op, "Expected element type bit width to be multiple of 8.");
-    }
     auto loc = op.getLoc();
     // Offsets are provided as scalar i64 by type converter.
     auto offsets = adaptor.getOffsets();
@@ -447,10 +443,9 @@ class UpdateOffsetToXeVMPattern
                   ConversionPatternRewriter &rewriter) const override {
     auto eTy = op.getTensorDescType().getElementType();
     auto eBw = eTy.getIntOrFloatBitWidth();
-    if (eBw % 8 != 0) {
+    if (eBw % 8 != 0)
       return rewriter.notifyMatchFailure(
           op, "Expected element type bit width to be multiple of 8.");
-    }
     auto loc = op.getLoc();
     // Scatter descriptor is provided as scalar i64 by type converter.
     // Offsets are provided as scalar i64 by type converter.
@@ -475,30 +470,27 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
     Value basePtrI64;
     // Load result or store value type can be vector or scalar.
     Type valOrResTy;
-    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
+    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
       valOrResTy = op.getResult().getType();
-    } else {
+    else
       valOrResTy = adaptor.getValue().getType();
-    }
     VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
     bool hasScalarVal = !valOrResVecTy;
     int64_t elemBitWidth =
         hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
                      : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
     // Element type must be multiple of 8 bits.
-    if (elemBitWidth % 8 != 0) {
+    if (elemBitWidth % 8 != 0)
       return rewriter.notifyMatchFailure(
           op, "Expected element type bit width to be multiple of 8.");
-    }
     int64_t elemByteSize = elemBitWidth / 8;
     // Default memory space is global.
     LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
         ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
     // If tensor descriptor is available, we use its memory space.
-    if (tdescTy) {
+    if (tdescTy)
       ptrTypeLLVM = LLVM::LLVMPointerType::get(
           ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
-    }
     // Base pointer can come from source (load) or dest (store).
     // If they are memrefs, we use their memory space.
     if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
@@ -524,32 +516,30 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
     Value offsets = adaptor.getOffsets();
     Value mask = adaptor.getMask();
     if (offsets) {
-      if (dyn_cast<VectorType>(offsets.getType())) {
+      if (dyn_cast<VectorType>(offsets.getType()))
         // Offset needs to be scalar. A single-element vector is converted to
         // a scalar by the type converter.
         return rewriter.notifyMatchFailure(op,
                                            "Expected offsets to be a scalar.");
-      } else {
+      else
         // If offsets are provided, we add them to the base pointer.
         // Offsets are in number of elements; we need to multiply by the
         // element byte size.
         basePtrI64 =
             addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
-      }
     }
     // Convert base pointer (i64) to LLVM pointer type.
     Value basePtrLLVM =
         LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
 
     Value maskForLane;
     VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
-    if (maskVecTy) {
+    if (maskVecTy)
       // Mask needs to be scalar. A single-element vector is converted to a
       // scalar by the type converter.
       return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
-    } else {
+    else
       maskForLane = mask;
-    }
     if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
       scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
                                          maskForLane, true, true);
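[Editor's note] addOffset is a file-local helper whose definition lies outside this diff. Going only by the comments in the hunk above (offsets are in number of elements; multiply by the element byte size), a plausible sketch of its shape in the same builder style; this is an assumption about the helper, not the file's actual body:

  // Sketch: scale an element-count offset to bytes, then add it to the
  // i64 base address. Assumes `offset` and `baseAddr` are i64 scalars.
  static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
                         Value baseAddr, Value offset, int64_t elemByteSize) {
    Value byteSize = arith::ConstantOp::create(
        rewriter, loc, rewriter.getI64IntegerAttr(elemByteSize));
    Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
    return arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
  }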
@@ -609,10 +599,9 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
     auto tdescTy = op.getTensorDescType();
     Value basePtrI64 = adaptor.getSource();
     // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
-    if (basePtrI64.getType() != rewriter.getI64Type()) {
+    if (basePtrI64.getType() != rewriter.getI64Type())
       basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                           basePtrI64);
-    }
     Value offsets = adaptor.getOffsets();
     if (offsets) {
       VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
@@ -637,10 +626,9 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
         elemByteSize = *op.getOffsetAlignByte();
       }
       if (elemBitWidth != 0) {
-        if (elemBitWidth % 8 != 0) {
+        if (elemBitWidth % 8 != 0)
           return rewriter.notifyMatchFailure(
               op, "Expected element type bit width to be multiple of 8.");
-        }
         elemByteSize = elemBitWidth / 8;
       }
       basePtrI64 =
@@ -651,10 +639,9 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
     LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
         ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
     // If tensor descriptor is available, we use its memory space.
-    if (tdescTy) {
+    if (tdescTy)
       ptrTypeLLVM = LLVM::LLVMPointerType::get(
           ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
-    }
     // If source is a memref, we use its memory space.
     if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
       auto addrSpace = memRefTy.getMemorySpaceAsInt();
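[Editor's note] getNumericXeVMAddrSpace, used twice in this hunk, maps the XeGPU memory-space enum to the numeric LLVM address space XeVM expects. A sketch under the assumption that the numbering follows the usual OpenCL convention; the concrete values and the helper's actual body are not confirmed by this diff:

  // Sketch: XeGPU memory space -> numeric LLVM/XeVM address space.
  static unsigned getNumericXeVMAddrSpace(xegpu::MemorySpace ms) {
    switch (ms) {
    case xegpu::MemorySpace::Global:
      return 1; // assumed OpenCL-style global, addrspace(1)
    case xegpu::MemorySpace::SLM:
      return 3; // assumed shared local memory, addrspace(3)
    }
    llvm_unreachable("unknown XeGPU memory space");
  }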
@@ -883,9 +870,8 @@ struct ConvertXeGPUToXeVMPass
       return VectorType::get(sum, elemType);
     });
     typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
-      if (type.isScattered()) {
+      if (type.isScattered())
         return IntegerType::get(&getContext(), 64);
-      }
       auto i32Type = IntegerType::get(&getContext(), 32);
       return VectorType::get(8, i32Type);
     });
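[Editor's note] Net effect of this type-converter rule: scattered tensor descriptors lower to a bare i64 (a flat per-lane address), while all other descriptors keep the 8 x i32 hardware payload. A hypothetical caller-side illustration, assuming a typeConverter and a descriptor type tdescTy in scope:

  // Sketch: query the converted form of a tensor-descriptor type.
  Type converted = typeConverter.convertType(tdescTy);
  // Scattered descriptor  -> i64 (flat address per lane).
  // Blocked 2D descriptor -> vector<8xi32> (block-descriptor payload).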