2828#include  " mlir/Target/LLVMIR/TypeToLLVM.h" 
2929
3030#include  " llvm/ADT/StringRef.h" 
31+ #include  " llvm/ADT/TypeSwitch.h" 
32+ #include  " llvm/ADT/identity.h" 
3133#include  " llvm/IR/Attributes.h" 
3234#include  " llvm/Support/ErrorHandling.h" 
3335#include  " llvm/Support/ModRef.h" 
@@ -935,69 +937,77 @@ struct TritonMatrix2DBlockPrefetchLowering
935937};
936938
937939template  <typename  OpType, typename  = std::enable_if_t <llvm::is_one_of<
938-                                OpType, TritonGEN::SIMDBlockReadOp ,
939-                                TritonGEN::SIMDBlockWriteOp >::value>>
940- static  std::string getSIMDBlockManglingName (OpType op, VectorType vecTy ) {
940+                                OpType, TritonGEN::SubGroupBlockReadOp ,
941+                                TritonGEN::SubGroupBlockWriteOp >::value>>
942+ static  std::string getSubGroupBlockManglingName (OpType op, Type type ) {
941943  constexpr  bool  isWrite =
942-       std::is_same<OpType, TritonGEN::SIMDBlockWriteOp >::value;
944+       std::is_same<OpType, TritonGEN::SubGroupBlockWriteOp >::value;
943945  const  LLVM::LLVMPointerType ptrTy = op.getPtr ().getType ();
944-   const  unsigned  numElems = vecTy.getNumElements ();
945946  //  Note: OCL builtin name here differs from regular mangling.
946947  std::string funcName = " intel_sub_group_block_"  ;
947948  if  constexpr  (isWrite)
948949    funcName += " write"  ;
949950  else 
950951    funcName += " read"  ;
951-   funcName += " _u"   + intel::getTypeMangling (vecTy.getElementType ()) +
952-               (numElems == 1  ? " "   : std::to_string (numElems));
953-   funcName =
954-       " _Z"   + std::to_string (funcName.size ()) + funcName + " PU3AS"   +
955-       std::to_string (ptrTy.getAddressSpace ()) +
956-       intel::getTypeMangling (vecTy.getElementType (), /* isUnsigned=*/ true );
952+   Type elementType =
953+       TypeSwitch<Type, Type>(type)
954+           .Case ([](VectorType vecType) { return  vecType.getElementType (); })
955+           //  Scalar case
956+           .Default (llvm::identity<Type>());
957+   const  unsigned  numElems =
958+       TypeSwitch<Type, unsigned >(type)
959+           .Case ([](VectorType vecType) { return  vecType.getNumElements (); })
960+           //  Scalar case
961+           .Default (0u );
962+   funcName += " _u"   + intel::getTypeMangling (elementType) +
963+               (numElems ? std::to_string (numElems) : " "  );
964+   funcName = " _Z"   + std::to_string (funcName.size ()) + funcName + " PU3AS"   +
965+              std::to_string (ptrTy.getAddressSpace ()) +
966+              intel::getTypeMangling (elementType, /* isUnsigned=*/ true );
957967  if  constexpr  (isWrite)
958-     funcName += intel::getTypeMangling (vecTy , /* isUnsigned=*/ true );
968+     funcName += intel::getTypeMangling (type , /* isUnsigned=*/ true );
959969  return  funcName;
960970}
961971
962- struct  TritonSIMDBlockReadLowering 
963-     : public ConvertOpToLLVMPattern<TritonGEN::SIMDBlockReadOp > {
972+ struct  TritonSubGroupBlockReadLowering 
973+     : public ConvertOpToLLVMPattern<TritonGEN::SubGroupBlockReadOp > {
964974  using  ConvertOpToLLVMPattern<
965-       TritonGEN::SIMDBlockReadOp >::ConvertOpToLLVMPattern;
975+       TritonGEN::SubGroupBlockReadOp >::ConvertOpToLLVMPattern;
966976
967977  LogicalResult
968-   matchAndRewrite (TritonGEN::SIMDBlockReadOp  op, OpAdaptor adaptor,
978+   matchAndRewrite (TritonGEN::SubGroupBlockReadOp  op, OpAdaptor adaptor,
969979                  ConversionPatternRewriter &rewriter) const  override  {
970980    LLVM::LLVMPointerType ptrTy = op.getPtr ().getType ();
971-     VectorType vecTy  = op.getRes ().getType ();
981+     Type type  = op.getRes ().getType ();
972982
973-     std::string funcName = getSIMDBlockManglingName (op, vecTy );
983+     std::string funcName = getSubGroupBlockManglingName (op, type );
974984    auto  memAttr = rewriter.getAttr <LLVM::MemoryEffectsAttr>(
975985        /* other=*/  LLVM::ModRefInfo::NoModRef,
976986        /* argMem=*/  LLVM::ModRefInfo::Ref,
977987        /* inaccessibleMem=*/  LLVM::ModRefInfo::NoModRef);
978988    auto  funcAttrs = noUnwindWillReturnAttrs;
979989    funcAttrs.memEffectsAttr  = memAttr;
980990    LLVM::CallOp call = createDeviceFunctionCall (
981-         rewriter, funcName, vecTy , {ptrTy}, {op.getPtr ()}, {}, funcAttrs, {});
991+         rewriter, funcName, type , {ptrTy}, {op.getPtr ()}, {}, funcAttrs, {});
982992
983993    rewriter.replaceOp (op, call.getResult ());
984994    return  success ();
985995  }
986996};
987997
988- struct  TritonSIMDBlockWriteLowering 
989-     : public ConvertOpToLLVMPattern<TritonGEN::SIMDBlockWriteOp > {
998+ struct  TritonSubGroupBlockWriteLowering 
999+     : public ConvertOpToLLVMPattern<TritonGEN::SubGroupBlockWriteOp > {
9901000  using  ConvertOpToLLVMPattern<
991-       TritonGEN::SIMDBlockWriteOp >::ConvertOpToLLVMPattern;
1001+       TritonGEN::SubGroupBlockWriteOp >::ConvertOpToLLVMPattern;
9921002
9931003  LogicalResult
994-   matchAndRewrite (TritonGEN::SIMDBlockWriteOp  op, OpAdaptor adaptor,
1004+   matchAndRewrite (TritonGEN::SubGroupBlockWriteOp  op, OpAdaptor adaptor,
9951005                  ConversionPatternRewriter &rewriter) const  override  {
9961006    MLIRContext *ctx = rewriter.getContext ();
9971007    LLVM::LLVMPointerType ptrTy = op.getPtr ().getType ();
998-     VectorType vecTy  = op.getVal ().getType ();
1008+     Type type  = op.getVal ().getType ();
9991009
1000-     std::string funcName = getSIMDBlockManglingName (op, vecTy );
1010+     std::string funcName = getSubGroupBlockManglingName (op, type );
10011011
10021012    auto  memAttr = rewriter.getAttr <LLVM::MemoryEffectsAttr>(
10031013        /* other=*/  LLVM::ModRefInfo::NoModRef,
@@ -1006,7 +1016,7 @@ struct TritonSIMDBlockWriteLowering
10061016    auto  funcAttrs = noUnwindWillReturnAttrs;
10071017    funcAttrs.memEffectsAttr  = memAttr;
10081018    LLVM::CallOp call = createDeviceFunctionCall (
1009-         rewriter, funcName, void_ty (ctx), {ptrTy, vecTy },
1019+         rewriter, funcName, void_ty (ctx), {ptrTy, type },
10101020        {op.getPtr (), op.getVal ()}, {}, funcAttrs);
10111021
10121022    rewriter.replaceOp (op, call);
@@ -1071,12 +1081,13 @@ struct TritonGENToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
10711081
10721082void  mlir::triton::populateTritonGENToLLVMConversionPatterns (
10731083    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1074-   patterns.add <
1075-       TritonGENSplitBarrierSignalLowering, TritonGENSplitBarrierWaitLowering,
1076-       TritonSubGroupReduceLowering, TritonSubGroupScanLowering,
1077-       TritonMatrixDPASLowering, TritonMatrix2DBlockLoadLowering,
1078-       TritonMatrix2DBlockStoreLowering, TritonMatrix2DBlockPrefetchLowering,
1079-       TritonSIMDBlockReadLowering, TritonSIMDBlockWriteLowering>(converter);
1084+   patterns
1085+       .add <TritonGENSplitBarrierSignalLowering,
1086+            TritonGENSplitBarrierWaitLowering, TritonSubGroupReduceLowering,
1087+            TritonSubGroupScanLowering, TritonMatrixDPASLowering,
1088+            TritonMatrix2DBlockLoadLowering, TritonMatrix2DBlockStoreLowering,
1089+            TritonMatrix2DBlockPrefetchLowering, TritonSubGroupBlockReadLowering,
1090+            TritonSubGroupBlockWriteLowering>(converter);
10801091}
10811092
10821093void  registerConvertTritonTritonGENToLLVMInterface (DialectRegistry ®istry) {
0 commit comments