2020#include  " mlir/Dialect/MemRef/IR/MemRef.h" 
2121#include  " mlir/Dialect/NVGPU/IR/NVGPUDialect.h" 
2222#include  " mlir/Dialect/SCF/Transforms/Patterns.h" 
23+ #include  " mlir/IR/Builders.h" 
2324#include  " mlir/IR/BuiltinTypes.h" 
24- #include  " mlir/IR/ImplicitLocOpBuilder.h" 
2525#include  " mlir/IR/PatternMatch.h" 
2626#include  " mlir/IR/TypeUtilities.h" 
2727#include  " mlir/IR/Value.h" 
@@ -114,7 +114,7 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
114114
115115  auto  makeConst = [&](int32_t  index) -> Value {
116116    return  LLVM::ConstantOp::create (rewriter, loc, IntegerType::get (ctx, 32 ),
117-                                               rewriter.getI32IntegerAttr (index));
117+                                     rewriter.getI32IntegerAttr (index));
118118  };
119119
120120  if  (arrayType) {
@@ -147,11 +147,11 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
147147        Value x1 =
148148            LLVM::ExtractValueOp::create (rewriter, loc, intrinsicResult, i * 2 );
149149        Value x2 = LLVM::ExtractValueOp::create (rewriter, loc, intrinsicResult,
150-                                                           i * 2  + 1 );
150+                                                 i * 2  + 1 );
151151        vec = LLVM::InsertElementOp::create (rewriter, loc, vec.getType (), vec,
152-                                                       x1, makeConst (0 ));
152+                                             x1, makeConst (0 ));
153153        vec = LLVM::InsertElementOp::create (rewriter, loc, vec.getType (), vec,
154-                                                       x2, makeConst (1 ));
154+                                             x2, makeConst (1 ));
155155        elements.push_back (vec);
156156      }
157157    }
@@ -160,7 +160,7 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
160160    Value result = LLVM::PoisonOp::create (rewriter, loc, arrayType);
161161    for  (const  auto  &el : llvm::enumerate (elements)) {
162162      result = LLVM::InsertValueOp::create (rewriter, loc, result, el.value (),
163-                                                      el.index ());
163+                                            el.index ());
164164    }
165165    return  result;
166166  }
@@ -208,8 +208,8 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
208208                         innerArrayTy.getElementType () == f32Ty)) {
209209      for  (unsigned  idx = 0 , innerSize = innerArrayTy.getNumElements ();
210210           idx < innerSize; idx++) {
211-         result.push_back (LLVM::ExtractElementOp::create (b, 
212-             toUse,
211+         result.push_back (LLVM::ExtractElementOp::create (
212+             b,  toUse,
213213            LLVM::ConstantOp::create (b, i64Ty, b.getI64IntegerAttr (idx))));
214214      }
215215      continue ;
@@ -285,8 +285,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
285285    Value srcPtr =
286286        getStridedElementPtr (rewriter, b.getLoc (), srcMemrefType,
287287                             adaptor.getSrcMemref (), adaptor.getIndices ());
288-     Value ldMatrixResult = NVVM::LdMatrixOp::create (b, 
289-         ldMatrixResultType, srcPtr,
288+     Value ldMatrixResult = NVVM::LdMatrixOp::create (
289+         b,  ldMatrixResultType, srcPtr,
290290        /* num=*/  op.getNumTiles (),
291291        /* layout=*/  op.getTranspose () ? NVVM::MMALayout::col
292292                                     : NVVM::MMALayout::row);
@@ -375,16 +375,16 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
375375    Type desiredRetTy = typeConverter->convertType (op->getResultTypes ()[0 ]);
376376    Type intrinsicResTy = inferIntrinsicResultType (
377377        typeConverter->convertType (op->getResultTypes ()[0 ]));
378-     Value intrinsicResult =  NVVM::MmaOp::create (b, 
379-         intrinsicResTy, matA, matB, matC,
380-         /* shape=*/  gemmShape,
381-         /* b1Op=*/  std::nullopt ,
382-         /* intOverflow=*/  overflow,
383-         /* multiplicandPtxTypes=*/ 
384-         std::array<NVVM::MMATypes, 2 >{*ptxTypeA, *ptxTypeB},
385-         /* multiplicandLayouts=*/ 
386-         std::array<NVVM::MMALayout, 2 >{NVVM::MMALayout::row, 
387-                                         NVVM::MMALayout::col});
378+     Value intrinsicResult =
379+         NVVM::MmaOp::create (b,  intrinsicResTy, matA, matB, matC,
380+                              /* shape=*/  gemmShape,
381+                              /* b1Op=*/  std::nullopt ,
382+                              /* intOverflow=*/  overflow,
383+                              /* multiplicandPtxTypes=*/ 
384+                              std::array<NVVM::MMATypes, 2 >{*ptxTypeA, *ptxTypeB},
385+                              /* multiplicandLayouts=*/ 
386+                              std::array<NVVM::MMALayout, 2 >{
387+                                 NVVM::MMALayout::row,  NVVM::MMALayout::col});
388388    rewriter.replaceOp (op, convertIntrinsicResult (op.getLoc (), intrinsicResTy,
389389                                                  desiredRetTy, intrinsicResult,
390390                                                  rewriter));
@@ -566,14 +566,15 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
566566  asmVals.push_back (indexData);
567567
568568  return  LLVM::InlineAsmOp::create (b,
569-       /* resultTypes=*/  intrinsicResultType,
570-       /* operands=*/  asmVals,
571-       /* asm_string=*/  asmStr,
572-       /* constraints=*/  constraintStr,
573-       /* has_side_effects=*/ true ,
574-       /* is_align_stack=*/ false , LLVM::TailCallKind::None,
575-       /* asm_dialect=*/  asmDialectAttr,
576-       /* operand_attrs=*/ ArrayAttr ());
569+                                    /* resultTypes=*/  intrinsicResultType,
570+                                    /* operands=*/  asmVals,
571+                                    /* asm_string=*/  asmStr,
572+                                    /* constraints=*/  constraintStr,
573+                                    /* has_side_effects=*/ true ,
574+                                    /* is_align_stack=*/ false ,
575+                                    LLVM::TailCallKind::None,
576+                                    /* asm_dialect=*/  asmDialectAttr,
577+                                    /* operand_attrs=*/ ArrayAttr ());
577578}
578579
579580// / Lowers `nvgpu.mma.sp.sync` to inline assembly.
@@ -698,12 +699,12 @@ struct NVGPUAsyncCopyLowering
698699      //  filled with zeros.
699700      Value c3I32 =
700701          LLVM::ConstantOp::create (b, b.getI32Type (), b.getI32IntegerAttr (3 ));
701-       Value bitwidth = LLVM::ConstantOp::create (b, 
702-           b.getI32Type (),
702+       Value bitwidth = LLVM::ConstantOp::create (
703+           b, b .getI32Type (),
703704          b.getI32IntegerAttr (srcMemrefType.getElementTypeBitWidth ()));
704705      Value srcElementsI32 = LLVM::TruncOp::create (b, b.getI32Type (), srcBytes);
705-       srcBytes = LLVM::LShrOp::create (b, 
706-           LLVM::MulOp::create (b, bitwidth, srcElementsI32), c3I32);
706+       srcBytes = LLVM::LShrOp::create (
707+           b,  LLVM::MulOp::create (b, bitwidth, srcElementsI32), c3I32);
707708    }
708709    //  Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
709710    //  16 dst bytes.
@@ -712,14 +713,15 @@ struct NVGPUAsyncCopyLowering
712713            ? NVVM::LoadCacheModifierKind::CG
713714            : NVVM::LoadCacheModifierKind::CA;
714715
715-     NVVM::CpAsyncOp::create (b, 
716-         dstPtr, scrPtr, rewriter.getI32IntegerAttr (sizeInBytes),
716+     NVVM::CpAsyncOp::create (
717+         b,  dstPtr, scrPtr, rewriter.getI32IntegerAttr (sizeInBytes),
717718        NVVM::LoadCacheModifierKindAttr::get (op->getContext (), cacheModifier),
718719        srcBytes);
719720
720721    //  Drop the result token.
721-     Value zero = LLVM::ConstantOp::create (b,
722-         IntegerType::get (op.getContext (), 32 ), rewriter.getI32IntegerAttr (0 ));
722+     Value zero =
723+         LLVM::ConstantOp::create (b, IntegerType::get (op.getContext (), 32 ),
724+                                  rewriter.getI32IntegerAttr (0 ));
723725    rewriter.replaceOp (op, zero);
724726    return  success ();
725727  }
@@ -735,9 +737,9 @@ struct NVGPUAsyncCreateGroupLowering
735737                  ConversionPatternRewriter &rewriter) const  override  {
736738    NVVM::CpAsyncCommitGroupOp::create (rewriter, op.getLoc ());
737739    //  Drop the result token.
738-     Value zero = LLVM::ConstantOp::create (rewriter,
739-         op-> getLoc (),  IntegerType::get (op.getContext (), 32 ),
740-         rewriter.getI32IntegerAttr (0 ));
740+     Value zero = LLVM::ConstantOp::create (rewriter, op-> getLoc (), 
741+                                            IntegerType::get (op.getContext (), 32 ),
742+                                            rewriter.getI32IntegerAttr (0 ));
741743    rewriter.replaceOp (op, zero);
742744    return  success ();
743745  }
@@ -771,8 +773,8 @@ struct NVGPUMBarrierCreateLowering
771773    SymbolTable symbolTable (moduleOp);
772774    OpBuilder::InsertionGuard guard (rewriter);
773775    rewriter.setInsertionPoint (&moduleOp.front ());
774-     auto  global = memref::GlobalOp::create (rewriter, 
775-         funcOp->getLoc (), " __mbarrier"  ,
776+     auto  global = memref::GlobalOp::create (
777+         rewriter,  funcOp->getLoc (), " __mbarrier"  ,
776778        /* sym_visibility=*/  rewriter.getStringAttr (" private"  ),
777779        /* type=*/  barrierType,
778780        /* initial_value=*/ ElementsAttr (),
@@ -1119,7 +1121,7 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
11191121
11201122static  Value makeI64Const (ImplicitLocOpBuilder &b, int32_t  index) {
11211123  return  LLVM::ConstantOp::create (b, b.getIntegerType (64 ),
1122-                                      b.getI32IntegerAttr (index));
1124+                                   b.getI32IntegerAttr (index));
11231125}
11241126
11251127// / Returns a Value that holds data type enum that is expected by CUDA driver.
@@ -1182,11 +1184,11 @@ struct NVGPUTmaCreateDescriptorOpLowering
11821184    auto  promotedOperands = getTypeConverter ()->promoteOperands (
11831185        b.getLoc (), op->getOperands (), adaptor.getOperands (), b);
11841186
1185-     Value boxArrayPtr = LLVM::AllocaOp::create (b, llvmPointerType, llvmInt64Type, 
1186-                                                   makeI64Const (b, 5 ));
1187+     Value boxArrayPtr = LLVM::AllocaOp::create (
1188+         b, llvmPointerType, llvmInt64Type,  makeI64Const (b, 5 ));
11871189    for  (auto  [index, value] : llvm::enumerate (adaptor.getBoxDimensions ())) {
11881190      Value gep = LLVM::GEPOp::create (b, llvmPointerType, llvmPointerType,
1189-                                          boxArrayPtr, makeI64Const (b, index));
1191+                                       boxArrayPtr, makeI64Const (b, index));
11901192      LLVM::StoreOp::create (b, value, gep);
11911193    }
11921194
@@ -1430,9 +1432,9 @@ struct NVGPUWarpgroupMmaOpLowering
14301432      auto  overflow = NVVM::MMAIntOverflowAttr::get (
14311433          op->getContext (), NVVM::MMAIntOverflow::wrapped);
14321434
1433-       return  NVVM::WgmmaMmaAsyncOp::create (b, 
1434-           matrixC.getType (), matrixC, descriptorA, descriptorB, shape, itypeA ,
1435-           itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1435+       return  NVVM::WgmmaMmaAsyncOp::create (
1436+           b,  matrixC.getType (), matrixC, descriptorA, descriptorB, shape,
1437+           itypeA,  itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
14361438          overflow);
14371439    }
14381440
@@ -1444,15 +1446,16 @@ struct NVGPUWarpgroupMmaOpLowering
14441446      //  Perform GEMM
14451447      SmallVector<Value> wgmmaResults;
14461448      for  (int  i = 0 ; i < iterationM; ++i) {
1447-         Value matrixC = LLVM::ExtractValueOp::create (b, adaptor.getMatrixC (), i);
1449+         Value matrixC =
1450+             LLVM::ExtractValueOp::create (b, adaptor.getMatrixC (), i);
14481451        for  (int  j = 0 ; j < iterationN; ++j)
14491452          for  (int  k = 0 ; k < iterationK; ++k)
14501453            matrixC = generateWgmma (i, j, k, matrixC);
14511454        wgmmaResults.push_back (matrixC);
14521455      }
14531456      for  (auto  [idx, matrix] : llvm::enumerate (wgmmaResults)) {
14541457        wgmmaResult = LLVM::InsertValueOp::create (b, wgmmaResult.getType (),
1455-                                                      wgmmaResult, matrix, idx);
1458+                                                   wgmmaResult, matrix, idx);
14561459      }
14571460      return  wgmmaResult;
14581461    }
@@ -1486,9 +1489,9 @@ struct NVGPUWarpgroupMmaOpLowering
14861489    // / (WgmmaGroupSyncAlignedOp) for group synchronization
14871490    // / (WgmmaWaitGroupSyncOp) after the instructions.
14881491    Value generateWarpgroupMma () {
1489-       NVVM::WgmmaFenceAlignedOp::create (b,  );
1492+       NVVM::WgmmaFenceAlignedOp::create (b);
14901493      Value wgmmaResult = generateWgmmaGroup ();
1491-       NVVM::WgmmaGroupSyncAlignedOp::create (b,  );
1494+       NVVM::WgmmaGroupSyncAlignedOp::create (b);
14921495      NVVM::WgmmaWaitGroupSyncOp::create (b, op.getWaitGroup ());
14931496      return  wgmmaResult;
14941497    }
@@ -1626,7 +1629,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering
16261629    auto  stype = cast<LLVM::LLVMStructType>(matriDValue.getType ());
16271630    for  (auto  [idx, matrixD] : llvm::enumerate (stype.getBody ())) {
16281631      auto  structType = cast<LLVM::LLVMStructType>(matrixD);
1629-       Value innerStructValue = LLVM::ExtractValueOp::create (b, matriDValue, idx);
1632+       Value innerStructValue =
1633+           LLVM::ExtractValueOp::create (b, matriDValue, idx);
16301634      storeFragmentedMatrix (b, innerStructValue, op.getDstMemref (), offset);
16311635      offset += structType.getBody ().size ();
16321636    }
@@ -1656,15 +1660,15 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
16561660      auto  structType = cast<LLVM::LLVMStructType>(s);
16571661      Value structValue = LLVM::ExtractValueOp::create (b, packStruct, idx);
16581662      for  (unsigned  i = 0 ; i < structType.getBody ().size (); ++i) {
1659-         structValue = LLVM::InsertValueOp::create (b,
1660-             structType, structValue,  zero, ArrayRef<int64_t >({i}));
1663+         structValue = LLVM::InsertValueOp::create (b, structType, structValue, 
1664+                                                    zero, ArrayRef<int64_t >({i}));
16611665      }
16621666      innerStructs.push_back (structValue);
16631667    }
16641668    //  Pack the inner structs into a single struct
16651669    for  (auto  [idx, matrix] : llvm::enumerate (innerStructs)) {
16661670      packStruct = LLVM::InsertValueOp::create (b, packStruct.getType (),
1667-                                                   packStruct, matrix, idx);
1671+                                                packStruct, matrix, idx);
16681672    }
16691673    rewriter.replaceOp (op, packStruct);
16701674    return  success ();
0 commit comments