1212// ===----------------------------------------------------------------------===//
1313
1414#include " mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
15+ #include " mlir/IR/BuiltinTypes.h"
16+ #include " mlir/IR/Diagnostics.h"
17+ #include " mlir/IR/Location.h"
18+ #include " mlir/IR/MLIRContext.h"
1519
20+ #include " mlir/Support/LLVM.h"
1621#include " llvm/ADT/StringExtras.h"
22+ #include " llvm/ADT/TypeSwitch.h"
1723#include " llvm/Support/DebugLog.h"
1824#include " llvm/Support/FormatVariadic.h"
25+ #include " llvm/Support/LogicalResult.h"
1926#include " llvm/Support/Regex.h"
2027
2128#define DEBUG_TYPE " ptx-builder"
@@ -31,35 +38,88 @@ using namespace NVVM;
3138
3239static constexpr int64_t kSharedMemorySpace = 3 ;
3340
34- static char getRegisterType (Type type) {
35- if (type.isInteger (1 ))
36- return ' b' ;
37- if (type.isInteger (16 ))
38- return ' h' ;
39- if (type.isInteger (32 ))
40- return ' r' ;
41- if (type.isInteger (64 ))
42- return ' l' ;
43- if (type.isF32 ())
44- return ' f' ;
45- if (type.isF64 ())
46- return ' d' ;
47- if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
48- // Shared address spaces is addressed with 32-bit pointers.
49- if (ptr.getAddressSpace () == kSharedMemorySpace ) {
41+ static FailureOr<char > getRegisterType (Type type, Location loc) {
42+ MLIRContext *ctx = type.getContext ();
43+ auto i16 = IntegerType::get (ctx, 16 );
44+ auto i32 = IntegerType::get (ctx, 32 );
45+ auto f32 = Float32Type::get (ctx);
46+
47+ auto getRegisterTypeForScalar = [&](Type type) -> FailureOr<char > {
48+ if (type.isInteger (1 ))
49+ return ' b' ;
50+ if (type.isInteger (16 ))
51+ return ' h' ;
52+ if (type.isInteger (32 ))
5053 return ' r' ;
54+ if (type.isInteger (64 ))
55+ return ' l' ;
56+ if (type.isF32 ())
57+ return ' f' ;
58+ if (type.isF64 ())
59+ return ' d' ;
60+ if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
61+ // Shared address spaces is addressed with 32-bit pointers.
62+ if (ptr.getAddressSpace () == kSharedMemorySpace ) {
63+ return ' r' ;
64+ }
65+ return ' l' ;
66+ }
67+ // register type for struct is not supported.
68+ mlir::emitError (
69+ loc, " The register type could not be deduced from MLIR type. The " )
70+ << type
71+ << " is not supported. Supported types are:"
72+ " i1, i16, i32, i64, f32, f64,"
73+ " pointers.\n Please use llvm.bitcast if you have different type. "
74+ " \n See the constraints from here: "
75+ " https://docs.nvidia.com/cuda/inline-ptx-assembly/"
76+ " index.html#constraints" ;
77+ return failure ();
78+ };
79+
80+ // Packed registers
81+ if (auto v = dyn_cast<VectorType>(type)) {
82+ assert (v.getNumDynamicDims () == 0 && " Dynamic vectors are not supported" );
83+
84+ int64_t lanes = v.getNumElements ();
85+ Type elem = v.getElementType ();
86+
87+ // Case 1. Single vector
88+ if (lanes <= 1 )
89+ return getRegisterTypeForScalar (elem);
90+
91+ // Case 2. Packed registers
92+ Type widened = elem;
93+ switch (lanes) {
94+
95+ case 2 :
96+ if (elem.isF16 () || elem.isBF16 ()) // vector<2xf16>
97+ widened = f32 ;
98+ else if (elem.isFloat (8 )) // vector<2xf8>
99+ widened = i16 ;
100+ break ;
101+ case 4 :
102+ if (elem.isInteger (8 )) // vector<i8x4>
103+ widened = i32 ;
104+ else if (elem.isFloat (8 )) // vector<f8x4>
105+ widened = f32 ;
106+ else if (elem.isFloat (4 )) // vector<f4x4>
107+ widened = i16 ;
108+ break ;
109+ // Other packing is not supported
110+ default :
111+ break ;
51112 }
52- return ' l ' ;
113+ return getRegisterTypeForScalar (widened) ;
53114 }
54- // register type for struct is not supported.
55- llvm_unreachable (" The register type could not deduced from MLIR type" );
56- return ' ?' ;
115+
116+ return getRegisterTypeForScalar (type);
57117}
58118
59- static char getRegisterType (Value v) {
119+ static FailureOr< char > getRegisterType (Value v, Location loc ) {
60120 if (v.getDefiningOp <LLVM::ConstantOp>())
61121 return ' n' ;
62- return getRegisterType (v.getType ());
122+ return getRegisterType (v.getType (), loc );
63123}
64124
65125// / Extract every element of a struct value.
@@ -75,10 +135,11 @@ static SmallVector<Value> extractStructElements(PatternRewriter &rewriter,
75135 return elems;
76136}
77137
78- void PtxBuilder::insertValue (Value v, PTXRegisterMod itype) {
138+ LogicalResult PtxBuilder::insertValue (Value v, PTXRegisterMod itype) {
79139 LDBG () << v << " \t Modifier : " << itype << " \n " ;
80140 registerModifiers.push_back (itype);
81141
142+ Location loc = interfaceOp->getLoc ();
82143 auto getModifier = [&]() -> const char * {
83144 switch (itype) {
84145 case PTXRegisterMod::Read:
@@ -111,21 +172,29 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
111172 }
112173 for (auto [idx, t] : llvm::enumerate (stype.getBody ())) {
113174 if (itype != PTXRegisterMod::Write) {
114- Value extractValue = LLVM::ExtractValueOp::create (
115- rewriter, interfaceOp-> getLoc () , v, idx);
175+ Value extractValue =
176+ LLVM::ExtractValueOp::create ( rewriter, loc , v, idx);
116177 addValue (extractValue);
117178 }
118179 if (itype == PTXRegisterMod::ReadWrite) {
119180 ss << idx << " ," ;
120181 } else {
121- ss << getModifier () << getRegisterType (t) << " ," ;
182+ FailureOr<char > regType = getRegisterType (t, loc);
183+ if (failed (regType))
184+ return rewriter.notifyMatchFailure (loc,
185+ " failed to get register type" );
186+ ss << getModifier () << regType.value () << " ," ;
122187 }
123188 }
124- return ;
189+ return success () ;
125190 }
126191 // Handle Scalars
127192 addValue (v);
128- ss << getModifier () << getRegisterType (v) << " ," ;
193+ FailureOr<char > regType = getRegisterType (v, loc);
194+ if (failed (regType))
195+ return rewriter.notifyMatchFailure (loc, " failed to get register type" );
196+ ss << getModifier () << regType.value () << " ," ;
197+ return success ();
129198}
130199
131200// / Check if the operation needs to pack and unpack results.
0 commit comments