@@ -64,15 +64,14 @@ static char getRegisterType(Value v) {
6464
6565// / Extract every element of a struct value.
6666static SmallVector<Value> extractStructElements (PatternRewriter &rewriter,
67- Location loc, Value agg) {
68- auto structTy = cast<LLVM::LLVMStructType>(agg.getType ());
67+ Location loc, Value structVal) {
68+ auto structTy = dyn_cast<LLVM::LLVMStructType>(structVal.getType ());
69+ assert (structTy && " expected LLVM struct" );
70+
6971 SmallVector<Value> elems;
70- elems.reserve (structTy.getBody ().size ());
71- for (auto [i, t] : llvm::enumerate (structTy.getBody ())) {
72- (void )t;
73- Value e = LLVM::ExtractValueOp::create (rewriter, loc, agg, i);
74- elems.push_back (e);
75- }
72+ for (unsigned i : llvm::seq<unsigned >(0 , structTy.getBody ().size ()))
73+ elems.push_back (rewriter.create <LLVM::ExtractValueOp>(loc, structVal, i));
74+
7675 return elems;
7776}
7877
@@ -81,15 +80,17 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
8180 registerModifiers.push_back (itype);
8281
8382 auto getModifier = [&]() -> const char * {
84- if (itype == PTXRegisterMod::ReadWrite) {
85- // "Read-Write modifier is not supported
86- // Interface canonicalize it later
87- return " +" ;
88- }
89- if (itype == PTXRegisterMod::Write) {
83+ switch (itype) {
84+ case PTXRegisterMod::Read:
85+ return " " ;
86+ case PTXRegisterMod::Write:
9087 return " =" ;
88+ case PTXRegisterMod::ReadWrite:
89+ // Read-Write modifier is not actually supported.
90+ // The interface will change it to "=" later and add an integer mapping.
91+ return " +" ;
9192 }
92- return " " ;
93+ llvm_unreachable ( " Unknown PTX register modifier " ) ;
9394 };
9495
9596 auto addValue = [&](Value v) {
@@ -134,12 +135,12 @@ needsPackUnpack(BasicPtxBuilderInterface interfaceOp,
134135 SmallVectorImpl<PTXRegisterMod> ®isterModifiers) {
135136 if (needsManualRegisterMapping)
136137 return false ;
137- const unsigned writeOnly = interfaceOp->getNumResults ();
138- const unsigned readWrite =
138+ const unsigned writeOnlyVals = interfaceOp->getNumResults ();
139+ const unsigned readWriteVals =
139140 llvm::count_if (registerModifiers, [](PTXRegisterMod m) {
140141 return m == PTXRegisterMod::ReadWrite;
141142 });
142- return (writeOnly + readWrite ) > 1 ;
143+ return (writeOnlyVals + readWriteVals ) > 1 ;
143144}
144145
145146// / Pack the result types of the interface operation.
@@ -219,14 +220,58 @@ static std::string canonicalizeRegisterConstraints(llvm::StringRef csv) {
219220 return os.str ();
220221}
221222
222- constexpr llvm::StringLiteral kReadWrite {" rw" };
223- constexpr llvm::StringLiteral kWriteOnly {" w" };
224- constexpr llvm::StringLiteral kReadOnly {" r" };
223+ constexpr llvm::StringLiteral kReadWritePrefix {" rw" };
224+ constexpr llvm::StringLiteral kWriteOnlyPrefix {" w" };
225+ constexpr llvm::StringLiteral kReadOnlyPrefix {" r" };
225226
226- // / Rewrites placeholders of the form `{$rN}`, `{$wN}`, `{$rwN}` in `asmText`
227- // / to compact `$K` indices where all `rw*` come first (ascending N), then `w*`,
228- // / then `r*`. Duplicates are de-duplicated when assigning numbers.
229- // / Unknown text is preserved verbatim.
227+ // / Returns a regex that matches {$rwN}, {$wN}, {$rN}
228+ static llvm::Regex getPredicateMappingRegex () {
229+ llvm::Regex rx (llvm::formatv (R"( \{\$({0}|{1}|{2})([0-9]+)\})" ,
230+ kReadWritePrefix , kWriteOnlyPrefix ,
231+ kReadOnlyPrefix )
232+ .str ());
233+ return rx;
234+ }
235+
236+ void mlir::NVVM::countPlaceholderNumbers (
237+ StringRef ptxCode, llvm::SmallDenseSet<unsigned int > &seenRW,
238+ llvm::SmallDenseSet<unsigned int > &seenW,
239+ llvm::SmallDenseSet<unsigned int > &seenR,
240+ llvm::SmallVectorImpl<unsigned int > &rwNums,
241+ llvm::SmallVectorImpl<unsigned int > &wNums,
242+ llvm::SmallVectorImpl<unsigned int > &rNums) {
243+
244+ llvm::Regex rx = getPredicateMappingRegex ();
245+ StringRef rest = ptxCode;
246+
247+ SmallVector<StringRef, 3 > m; // 0: full, 1: kind, 2: number
248+ while (!rest.empty () && rx.match (rest, &m)) {
249+ unsigned num = 0 ;
250+ (void )m[2 ].getAsInteger (10 , num);
251+
252+ if (m[1 ].equals_insensitive (kReadWritePrefix )) {
253+ if (seenRW.insert (num).second )
254+ rwNums.push_back (num);
255+ } else if (m[1 ].equals_insensitive (kWriteOnlyPrefix )) {
256+ if (seenW.insert (num).second )
257+ wNums.push_back (num);
258+ } else {
259+ if (seenR.insert (num).second )
260+ rNums.push_back (num);
261+ }
262+
263+ const size_t advance = (size_t )(m[0 ].data () - rest.data ()) + m[0 ].size ();
264+ rest = rest.drop_front (advance);
265+ }
266+ }
267+
268+ // / Rewrites `{$rwN}`, `{$wN}`, and `{$rN}` placeholders in `ptxCode` into
269+ // / compact `$K` indices:
270+ // / - All `rw*` first (sorted by N),
271+ // / - Then `w*`,
272+ // / - Then `r*`.
273+ // / If there is a predicate, it always comes at the end.
274+ // / Each number is assigned once; duplicates are ignored.
230275// /
231276// / Example Input:
232277// / "{
@@ -246,42 +291,19 @@ constexpr llvm::StringLiteral kReadOnly{"r"};
246291// / selp.s32 $2, $4, $5, p;
247292// / selp.s32 $3, $4, $5, p;
248293// / }\n"
249- static std::string rewriteAsmPlaceholders (llvm::StringRef asmText) {
250- // Match {$rwN}, {$wN}, {$rN}
251- llvm::Regex rx (llvm::formatv (R"( \{\$({0}|{1}|{2})([0-9]+)\})" , kReadWrite ,
252- kWriteOnly , kReadOnly )
253- .str ());
254-
294+ static std::string rewriteAsmPlaceholders (llvm::StringRef ptxCode) {
255295 llvm::SmallDenseSet<unsigned > seenRW, seenW, seenR;
256296 llvm::SmallVector<unsigned > rwNums, wNums, rNums;
257297
258- {
259- StringRef rest = asmText;
260- SmallVector<StringRef, 3 > m; // 0: full, 1: kind, 2: number
261- while (!rest.empty () && rx.match (rest, &m)) {
262- unsigned num = 0 ;
263- (void )m[2 ].getAsInteger (10 , num);
264-
265- if (m[1 ].equals_insensitive (kReadWrite )) {
266- if (seenRW.insert (num).second )
267- rwNums.push_back (num);
268- } else if (m[1 ].equals_insensitive (kWriteOnly )) {
269- if (seenW.insert (num).second )
270- wNums.push_back (num);
271- } else {
272- if (seenR.insert (num).second )
273- rNums.push_back (num);
274- }
275-
276- const size_t advance = (size_t )(m[0 ].data () - rest.data ()) + m[0 ].size ();
277- rest = rest.drop_front (advance);
278- }
279- }
298+ // Step 1. Count Register Placeholder numbers
299+ countPlaceholderNumbers (ptxCode, seenRW, seenW, seenR, rwNums, wNums, rNums);
280300
301+ // Step 2. Sort the Register Placeholder numbers
281302 llvm::sort (rwNums);
282303 llvm::sort (wNums);
283304 llvm::sort (rNums);
284305
306+ // Step 3. Create mapping from original to new IDs
285307 llvm::DenseMap<unsigned , unsigned > rwMap, wMap, rMap;
286308 unsigned nextId = 0 ;
287309 for (unsigned n : rwNums)
@@ -291,27 +313,28 @@ static std::string rewriteAsmPlaceholders(llvm::StringRef asmText) {
291313 for (unsigned n : rNums)
292314 rMap[n] = nextId++;
293315
316+ // Step 4. Rewrite the PTX code with new IDs
294317 std::string out;
295- out.reserve (asmText.size ());
296-
318+ out.reserve (ptxCode.size ());
297319 size_t prev = 0 ;
298- StringRef rest = asmText ;
320+ StringRef rest = ptxCode ;
299321 SmallVector<StringRef, 3 > m;
322+ llvm::Regex rx = getPredicateMappingRegex ();
300323 while (!rest.empty () && rx.match (rest, &m)) {
301324 // Compute absolute match bounds in the original buffer.
302- size_t absStart = (size_t )(m[0 ].data () - asmText .data ());
325+ size_t absStart = (size_t )(m[0 ].data () - ptxCode .data ());
303326 size_t absEnd = absStart + m[0 ].size ();
304327
305328 // Emit text before the match.
306- out.append (asmText .data () + prev, asmText .data () + absStart);
329+ out.append (ptxCode .data () + prev, ptxCode .data () + absStart);
307330
308331 // Emit compact $K
309332 unsigned num = 0 ;
310333 (void )m[2 ].getAsInteger (10 , num);
311334 unsigned id = 0 ;
312- if (m[1 ].equals_insensitive (kReadWrite ))
335+ if (m[1 ].equals_insensitive (kReadWritePrefix ))
313336 id = rwMap.lookup (num);
314- else if (m[1 ].equals_insensitive (kWriteOnly ))
337+ else if (m[1 ].equals_insensitive (kWriteOnlyPrefix ))
315338 id = wMap.lookup (num);
316339 else
317340 id = rMap.lookup (num);
@@ -321,13 +344,12 @@ static std::string rewriteAsmPlaceholders(llvm::StringRef asmText) {
321344
322345 prev = absEnd;
323346
324- // Advance search window.
325347 const size_t advance = (size_t )(m[0 ].data () - rest.data ()) + m[0 ].size ();
326348 rest = rest.drop_front (advance);
327349 }
328350
329- // Tail.
330- out.append (asmText .data () + prev, asmText .data () + asmText .size ());
351+ // Step 5. Tail.
352+ out.append (ptxCode .data () + prev, ptxCode .data () + ptxCode .size ());
331353 return out;
332354}
333355
0 commit comments