Skip to content

Commit 87d328b

Browse files
committed
address comments, make code more readable
1 parent 377e536 commit 87d328b

File tree

2 files changed

+95
-63
lines changed

2 files changed

+95
-63
lines changed

mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,16 @@ class PtxBuilder {
9393
void buildAndReplaceOp();
9494
};
9595

96+
/// Count the number of placeholder variables such as {$r}, {$w}, {$rw} in the
97+
/// PTX code.
98+
void countPlaceholderNumbers(StringRef ptxCode,
99+
llvm::SmallDenseSet<unsigned> &seenRW,
100+
llvm::SmallDenseSet<unsigned> &seenW,
101+
llvm::SmallDenseSet<unsigned> &seenR,
102+
llvm::SmallVectorImpl<unsigned> &rwNums,
103+
llvm::SmallVectorImpl<unsigned> &wNums,
104+
llvm::SmallVectorImpl<unsigned> &rNums);
105+
96106
} // namespace NVVM
97107
} // namespace mlir
98108

mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp

Lines changed: 85 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,14 @@ static char getRegisterType(Value v) {
6464

6565
/// Extract every element of a struct value.
6666
static SmallVector<Value> extractStructElements(PatternRewriter &rewriter,
67-
Location loc, Value agg) {
68-
auto structTy = cast<LLVM::LLVMStructType>(agg.getType());
67+
Location loc, Value structVal) {
68+
auto structTy = dyn_cast<LLVM::LLVMStructType>(structVal.getType());
69+
assert(structTy && "expected LLVM struct");
70+
6971
SmallVector<Value> elems;
70-
elems.reserve(structTy.getBody().size());
71-
for (auto [i, t] : llvm::enumerate(structTy.getBody())) {
72-
(void)t;
73-
Value e = LLVM::ExtractValueOp::create(rewriter, loc, agg, i);
74-
elems.push_back(e);
75-
}
72+
for (unsigned i : llvm::seq<unsigned>(0, structTy.getBody().size()))
73+
elems.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, structVal, i));
74+
7675
return elems;
7776
}
7877

@@ -81,15 +80,17 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
8180
registerModifiers.push_back(itype);
8281

8382
auto getModifier = [&]() -> const char * {
84-
if (itype == PTXRegisterMod::ReadWrite) {
85-
// "Read-Write modifier is not supported
86-
// Interface canonicalize it later
87-
return "+";
88-
}
89-
if (itype == PTXRegisterMod::Write) {
83+
switch (itype) {
84+
case PTXRegisterMod::Read:
85+
return "";
86+
case PTXRegisterMod::Write:
9087
return "=";
88+
case PTXRegisterMod::ReadWrite:
89+
// "Read-Write modifier is not actually supported
90+
// Interface will change it to "=" later and add integer mapping
91+
return "+";
9192
}
92-
return "";
93+
llvm_unreachable("Unknown PTX register modifier");
9394
};
9495

9596
auto addValue = [&](Value v) {
@@ -134,12 +135,12 @@ needsPackUnpack(BasicPtxBuilderInterface interfaceOp,
134135
SmallVectorImpl<PTXRegisterMod> &registerModifiers) {
135136
if (needsManualRegisterMapping)
136137
return false;
137-
const unsigned writeOnly = interfaceOp->getNumResults();
138-
const unsigned readWrite =
138+
const unsigned writeOnlyVals = interfaceOp->getNumResults();
139+
const unsigned readWriteVals =
139140
llvm::count_if(registerModifiers, [](PTXRegisterMod m) {
140141
return m == PTXRegisterMod::ReadWrite;
141142
});
142-
return (writeOnly + readWrite) > 1;
143+
return (writeOnlyVals + readWriteVals) > 1;
143144
}
144145

145146
/// Pack the result types of the interface operation.
@@ -219,14 +220,58 @@ static std::string canonicalizeRegisterConstraints(llvm::StringRef csv) {
219220
return os.str();
220221
}
221222

222-
constexpr llvm::StringLiteral kReadWrite{"rw"};
223-
constexpr llvm::StringLiteral kWriteOnly{"w"};
224-
constexpr llvm::StringLiteral kReadOnly{"r"};
223+
constexpr llvm::StringLiteral kReadWritePrefix{"rw"};
224+
constexpr llvm::StringLiteral kWriteOnlyPrefix{"w"};
225+
constexpr llvm::StringLiteral kReadOnlyPrefix{"r"};
225226

226-
/// Rewrites placeholders of the form `{$rN}`, `{$wN}`, `{$rwN}` in `asmText`
227-
/// to compact `$K` indices where all `rw*` come first (ascending N), then `w*`,
228-
/// then `r*`. Duplicates are de-duplicated when assigning numbers.
229-
/// Unknown text is preserved verbatim.
227+
/// Returns a regex that matches {$rwN}, {$wN}, {$rN}
228+
static llvm::Regex getPredicateMappingRegex() {
229+
llvm::Regex rx(llvm::formatv(R"(\{\$({0}|{1}|{2})([0-9]+)\})",
230+
kReadWritePrefix, kWriteOnlyPrefix,
231+
kReadOnlyPrefix)
232+
.str());
233+
return rx;
234+
}
235+
236+
void mlir::NVVM::countPlaceholderNumbers(
237+
StringRef ptxCode, llvm::SmallDenseSet<unsigned int> &seenRW,
238+
llvm::SmallDenseSet<unsigned int> &seenW,
239+
llvm::SmallDenseSet<unsigned int> &seenR,
240+
llvm::SmallVectorImpl<unsigned int> &rwNums,
241+
llvm::SmallVectorImpl<unsigned int> &wNums,
242+
llvm::SmallVectorImpl<unsigned int> &rNums) {
243+
244+
llvm::Regex rx = getPredicateMappingRegex();
245+
StringRef rest = ptxCode;
246+
247+
SmallVector<StringRef, 3> m; // 0: full, 1: kind, 2: number
248+
while (!rest.empty() && rx.match(rest, &m)) {
249+
unsigned num = 0;
250+
(void)m[2].getAsInteger(10, num);
251+
252+
if (m[1].equals_insensitive(kReadWritePrefix)) {
253+
if (seenRW.insert(num).second)
254+
rwNums.push_back(num);
255+
} else if (m[1].equals_insensitive(kWriteOnlyPrefix)) {
256+
if (seenW.insert(num).second)
257+
wNums.push_back(num);
258+
} else {
259+
if (seenR.insert(num).second)
260+
rNums.push_back(num);
261+
}
262+
263+
const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
264+
rest = rest.drop_front(advance);
265+
}
266+
}
267+
268+
/// Rewrites `{$rwN}`, `{$wN}`, and `{$rN}` placeholders in `ptxCode` into
269+
/// compact `$K` indices:
270+
/// - All `rw*` first (sorted by N),
271+
/// - Then `w*`,
272+
/// - Then `r*`.
273+
/// If there a predicate, it comes always in the end.
274+
/// Each number is assigned once; duplicates are ignored.
230275
///
231276
/// Example Input:
232277
/// "{
@@ -246,42 +291,19 @@ constexpr llvm::StringLiteral kReadOnly{"r"};
246291
/// selp.s32 $2, $4, $5, p;
247292
/// selp.s32 $3, $4, $5, p;
248293
/// }\n"
249-
static std::string rewriteAsmPlaceholders(llvm::StringRef asmText) {
250-
// Match {$rwN}, {$wN}, {$rN}
251-
llvm::Regex rx(llvm::formatv(R"(\{\$({0}|{1}|{2})([0-9]+)\})", kReadWrite,
252-
kWriteOnly, kReadOnly)
253-
.str());
254-
294+
static std::string rewriteAsmPlaceholders(llvm::StringRef ptxCode) {
255295
llvm::SmallDenseSet<unsigned> seenRW, seenW, seenR;
256296
llvm::SmallVector<unsigned> rwNums, wNums, rNums;
257297

258-
{
259-
StringRef rest = asmText;
260-
SmallVector<StringRef, 3> m; // 0: full, 1: kind, 2: number
261-
while (!rest.empty() && rx.match(rest, &m)) {
262-
unsigned num = 0;
263-
(void)m[2].getAsInteger(10, num);
264-
265-
if (m[1].equals_insensitive(kReadWrite)) {
266-
if (seenRW.insert(num).second)
267-
rwNums.push_back(num);
268-
} else if (m[1].equals_insensitive(kWriteOnly)) {
269-
if (seenW.insert(num).second)
270-
wNums.push_back(num);
271-
} else {
272-
if (seenR.insert(num).second)
273-
rNums.push_back(num);
274-
}
275-
276-
const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
277-
rest = rest.drop_front(advance);
278-
}
279-
}
298+
// Step 1. Count Register Placeholder numbers
299+
countPlaceholderNumbers(ptxCode, seenRW, seenW, seenR, rwNums, wNums, rNums);
280300

301+
// Step 2. Sort the Register Placeholder numbers
281302
llvm::sort(rwNums);
282303
llvm::sort(wNums);
283304
llvm::sort(rNums);
284305

306+
// Step 3. Create mapping from original to new IDs
285307
llvm::DenseMap<unsigned, unsigned> rwMap, wMap, rMap;
286308
unsigned nextId = 0;
287309
for (unsigned n : rwNums)
@@ -291,27 +313,28 @@ static std::string rewriteAsmPlaceholders(llvm::StringRef asmText) {
291313
for (unsigned n : rNums)
292314
rMap[n] = nextId++;
293315

316+
// Step 4. Rewrite the PTX code with new IDs
294317
std::string out;
295-
out.reserve(asmText.size());
296-
318+
out.reserve(ptxCode.size());
297319
size_t prev = 0;
298-
StringRef rest = asmText;
320+
StringRef rest = ptxCode;
299321
SmallVector<StringRef, 3> m;
322+
llvm::Regex rx = getPredicateMappingRegex();
300323
while (!rest.empty() && rx.match(rest, &m)) {
301324
// Compute absolute match bounds in the original buffer.
302-
size_t absStart = (size_t)(m[0].data() - asmText.data());
325+
size_t absStart = (size_t)(m[0].data() - ptxCode.data());
303326
size_t absEnd = absStart + m[0].size();
304327

305328
// Emit text before the match.
306-
out.append(asmText.data() + prev, asmText.data() + absStart);
329+
out.append(ptxCode.data() + prev, ptxCode.data() + absStart);
307330

308331
// Emit compact $K
309332
unsigned num = 0;
310333
(void)m[2].getAsInteger(10, num);
311334
unsigned id = 0;
312-
if (m[1].equals_insensitive(kReadWrite))
335+
if (m[1].equals_insensitive(kReadWritePrefix))
313336
id = rwMap.lookup(num);
314-
else if (m[1].equals_insensitive(kWriteOnly))
337+
else if (m[1].equals_insensitive(kWriteOnlyPrefix))
315338
id = wMap.lookup(num);
316339
else
317340
id = rMap.lookup(num);
@@ -321,13 +344,12 @@ static std::string rewriteAsmPlaceholders(llvm::StringRef asmText) {
321344

322345
prev = absEnd;
323346

324-
// Advance search window.
325347
const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
326348
rest = rest.drop_front(advance);
327349
}
328350

329-
// Tail.
330-
out.append(asmText.data() + prev, asmText.data() + asmText.size());
351+
// Step 5. Tail.
352+
out.append(ptxCode.data() + prev, ptxCode.data() + ptxCode.size());
331353
return out;
332354
}
333355

0 commit comments

Comments
 (0)