Skip to content

Commit c312060

Browse files
author
Kshitij
committed
tmp
1 parent 7361395 commit c312060

File tree

5 files changed

+189
-95
lines changed

5 files changed

+189
-95
lines changed

mlir/examples/toy/Ch2/CMakeLists.txt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
# For a better template to copy, see examples/standalone
22
add_subdirectory(include)
33

4+
add_embedded_lit_tests(EmbeddedLitTestsGen
5+
"include/toy/TestPasses.td"
6+
"/disk1/kjain/workspace/llvm-project/mlir/examples/toy/Ch2/jkshtj"
7+
)
8+
49
set(LLVM_LINK_COMPONENTS
510
Support
611
)
@@ -13,8 +18,9 @@ add_toy_chapter(toyc-ch2
1318

1419
DEPENDS
1520
ToyCh2OpsIncGen
16-
21+
EmbeddedLitTestsGen
1722
)
23+
1824
include_directories(include/)
1925
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
2026
target_link_libraries(toyc-ch2

mlir/include/mlir/IR/Testable.td

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,4 @@ class LitTest<string name, code snippet, list<string> run = [], list<string> che
3131
list<string> checkLines = check;
3232
}
3333

34-
// Base class for elements that can have auto-generated LIT tests
35-
class Testable {
36-
// List of LIT tests associated with this element
37-
list<LitTest> tests = [];
38-
}
39-
4034
#endif // TESTABLE

mlir/include/mlir/Pass/PassBase.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class Statistic<string varName, string statName, string desc> {
6464
// Pass
6565
//===----------------------------------------------------------------------===//
6666

67-
class PassBase<string passArg, string base> : Testable {
67+
class PassBase<string passArg, string base> {
6868
// The command line argument of the pass.
6969
string argument = passArg;
7070

mlir/test/mlir-tblgen/gen-lit-tests.td

Lines changed: 47 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2,64 +2,58 @@
22

33
include "mlir/Pass/PassBase.td"
44
include "mlir/IR/Testable.td"
5+
include "mlir/IR/OpBase.td"
56

6-
def TestPassWithEmbeddedLitTests : Pass<"test-pass-with-embedded-lit-tests"> {
7-
let summary = "pass summary";
7+
def Test_Dialect : Dialect {
8+
let name = "test";
9+
let cppNamespace = "test";
10+
}
11+
12+
def TestOp : Op<Test_Dialect, "test_op"> {
13+
let summary = "test op with mlir_example code blocks";
814
let description = [{
9-
Pass description
15+
This operation demonstrates the mlir_example feature for ops.
16+
17+
Basic usage:
18+
```mlir_example
19+
func.func @foo(%arg0: i32) -> i32 {
20+
%0 = test.test_op %arg0 : i32
21+
return %0 : i32
22+
}
23+
```
24+
25+
And some more examples -
26+
27+
```mlir_example
28+
func.func @foo1(%arg1: i32) -> i32 {
29+
%0 = test.test_op %arg1 : i32
30+
return %0 : i32
31+
}
32+
```
1033
}];
11-
12-
let tests = [
13-
LitTest<
14-
"lit_test_file_1.mlir",
15-
[{
16-
func.func @test1() {
17-
return 42;
18-
}
19-
}],
20-
[
21-
"// RUN: mlir-opt %s --verify-roundtrip | FileCheck %s",
22-
],
23-
[
24-
"// RANDOM-CHECK-LABEL: func.func @test1",
25-
]
26-
>,
27-
LitTest<
28-
"lit_test_file_2.mlir",
29-
[{
30-
func.func @test2() {
31-
return 42;
32-
}
33-
}],
34-
[
35-
"// RUN: mlir-opt %s --verify-roundtrip | FileCheck %s",
36-
],
37-
[
38-
"// RANDOM-CHECK-LABEL: func.func @test2",
39-
]
40-
>,
41-
];
42-
}
4334

44-
// CHECK-LABEL: // Generated 2 LIT test files
45-
// CHECK: // Use the following files for LIT testing:
35+
let arguments = (ins I32:$input);
36+
let results = (outs I32:$output);
37+
}
4638

47-
// CHECK: // File: generated_TestPassWithEmbeddedLitTests_lit_test_file_1.mlir
48-
// CHECK: // --- BEGIN generated_TestPassWithEmbeddedLitTests_lit_test_file_1.mlir ---
49-
// CHECK: // RUN: mlir-opt %s --verify-roundtrip | FileCheck %s
50-
// CHECK: // Generated from TableGen definition: TestPassWithEmbeddedLitTests
51-
// CHECK: func.func @test1() {
52-
// CHECK: return 42;
39+
// CHECK: // File: generated_TestOp_example_0.mlir
40+
// CHECK: // --- BEGIN generated_TestOp_example_0.mlir ---
41+
// CHECK: // RUN: mlir-opt %s --verify-roundtrip
42+
// CHECK: // Generated from TableGen definition: TestOp
43+
// CHECK: func.func @foo(%arg0: i32) -> i32 {
44+
// CHECK: %0 = test.test_op %arg0 : i32
45+
// CHECK: return %0 : i32
5346
// CHECK: }
54-
// CHECK: // RANDOM-CHECK-LABEL: func.func @test1
55-
// CHECK: --- END generated_TestPassWithEmbeddedLitTests_lit_test_file_1.mlir ---
47+
// CHECK: // --- END generated_TestOp_example_0.mlir ---
5648

57-
// CHECK: // File: generated_TestPassWithEmbeddedLitTests_lit_test_file_2.mlir
58-
// CHECK: // --- BEGIN generated_TestPassWithEmbeddedLitTests_lit_test_file_2.mlir ---
59-
// CHECK: // RUN: mlir-opt %s --verify-roundtrip | FileCheck %s
60-
// CHECK: // Generated from TableGen definition: TestPassWithEmbeddedLitTests
61-
// CHECK: func.func @test2() {
62-
// CHECK: return 42;
49+
// CHECK: // File: generated_TestOp_example_1.mlir
50+
// CHECK: // --- BEGIN generated_TestOp_example_1.mlir ---
51+
// CHECK: // RUN: mlir-opt %s --verify-roundtrip
52+
// CHECK: // Generated from TableGen definition: TestOp
53+
// CHECK: func.func @bar(%arg0: i32, %arg1: i32) -> i32 {
54+
// CHECK: %0 = test.test_op %arg0 : i32
55+
// CHECK: %1 = test.test_op %arg1 : i32
56+
// CHECK: %2 = arith.addi %0, %1 : i32
57+
// CHECK: return %2 : i32
6358
// CHECK: }
64-
// CHECK: // RANDOM-CHECK-LABEL: func.func @test2
65-
// CHECK: // --- END generated_TestPassWithEmbeddedLitTests_lit_test_file_2.mlir ---
59+
// CHECK: // --- END generated_TestOp_example_1.mlir ---

mlir/tools/mlir-tblgen/LitTestGen.cpp

Lines changed: 134 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "llvm/Support/CommandLine.h"
2020
#include "llvm/Support/FormatVariadic.h"
2121
#include "llvm/Support/Path.h"
22+
#include "llvm/Support/Regex.h"
2223
#include "llvm/TableGen/Error.h"
2324
#include "llvm/TableGen/Record.h"
2425

@@ -41,36 +42,136 @@ static llvm::cl::opt<std::string>
4142
struct LitTest {
4243
std::string sourceDefName;
4344
std::string testFileName;
44-
std::string irSnippet;
45+
std::string irSnippet;
4546
llvm::SmallVector<std::string> runLines;
4647
llvm::SmallVector<std::string> checkLines;
4748
};
4849

50+
/// Extract code snippets with mlir_example tag from a description field.
51+
/// Returns a vector of code snippets found within ```mlir_example ... ``` blocks.
52+
static llvm::SmallVector<std::string> extractMlirExamples(llvm::StringRef description) {
53+
llvm::SmallVector<std::string> examples;
54+
55+
// Pattern to match ```mlir_example ... ``` code blocks
56+
// [^\n]* matches rest of line after mlir_example
57+
// \n matches the newline after the opening fence
58+
// (.+?) captures the code content (non-greedy)
59+
// ``` matches the closing fence
60+
llvm::Regex codeBlockRegex("```mlir_example(.+)```");
61+
62+
llvm::StringRef remaining = description;
63+
llvm::SmallVector<llvm::StringRef> matches;
64+
65+
while (codeBlockRegex.match(remaining, &matches)) {
66+
if (matches.size() >= 2) {
67+
// matches[1] contains the captured group (the code content)
68+
std::string code = matches[1].str();
69+
70+
llvm::errs() << "DEBUG: Extracted raw code:\n[" << code << "]\n";
71+
72+
// Remove leading/trailing whitespace and comment markers (# prefix)
73+
llvm::SmallVector<llvm::StringRef> lines;
74+
llvm::StringRef codeRef(code);
75+
codeRef.split(lines, '\n', -1, false);
76+
77+
std::string processedCode;
78+
for (llvm::StringRef line : lines) {
79+
line = line.ltrim();
80+
// Remove leading # comment markers if present
81+
if (line.starts_with("#")) {
82+
line = line.drop_front(1).ltrim();
83+
}
84+
if (!line.empty() || !processedCode.empty()) {
85+
processedCode += line.str() + "\n";
86+
}
87+
}
88+
89+
// // Remove trailing empty lines
90+
// while (!processedCode.empty() && processedCode.back() == '\n') {
91+
// size_t lastNewline = processedCode.find_last_not_of('\n');
92+
// if (lastNewline == std::string::npos) {
93+
// processedCode.clear();
94+
// break;
95+
// }
96+
// processedCode = processedCode.substr(0, lastNewline + 1) + "\n";
97+
// }
98+
99+
if (!processedCode.empty()) {
100+
examples.push_back(processedCode);
101+
}
102+
}
103+
104+
// Move past this match to find the next one
105+
size_t matchEnd = remaining.find("```", remaining.find("```mlir_example") + 15);
106+
if (matchEnd == llvm::StringRef::npos)
107+
break;
108+
remaining = remaining.substr(matchEnd + 3);
109+
}
110+
111+
return examples;
112+
}
113+
49114
static llvm::SmallVector<LitTest> extractTestsFromRecord(const llvm::Record *record,
50115
llvm::StringRef dialectName = "") {
51116
llvm::SmallVector<LitTest> tests;
52-
53-
// Check if the record has a tests field
117+
118+
// Try to extract mlir_example code blocks from the description field
119+
const llvm::RecordVal *descVal = record->getValue("description");
120+
if (descVal) {
121+
llvm::StringRef description = record->getValueAsString("description");
122+
llvm::errs() << "DEBUG: Record: " << record->getName() << "\n";
123+
llvm::errs() << "DEBUG: Description length: " << description.size() << "\n";
124+
llvm::errs() << "DEBUG: Description content:\n" << description << "\n";
125+
llvm::errs() << "DEBUG: ---\n";
126+
if (!description.empty()) {
127+
llvm::SmallVector<std::string> examples = extractMlirExamples(description);
128+
llvm::errs() << "DEBUG: Found " << examples.size() << " examples\n";
129+
130+
// Create a LitTest for each extracted example
131+
for (size_t i = 0; i < examples.size(); ++i) {
132+
std::string testFileName;
133+
if (examples.size() == 1) {
134+
testFileName = "example.mlir";
135+
} else {
136+
testFileName = formatv("example_{0}.mlir", i);
137+
}
138+
139+
// Generate default RUN line with --verify-roundtrip
140+
llvm::SmallVector<std::string> runLines;
141+
runLines.push_back("// RUN: mlir-opt %s --verify-roundtrip");
142+
143+
tests.push_back(LitTest {
144+
record->getName().str(),
145+
testFileName,
146+
examples[i],
147+
runLines,
148+
{} // No CHECK lines by default
149+
});
150+
}
151+
}
152+
}
153+
154+
// Fall back to checking for the old tests field for backward compatibility
54155
const llvm::RecordVal *testsVal = record->getValue("tests");
55156
if (!testsVal)
56157
return tests;
57-
58-
const llvm::ListInit *testsList =
158+
159+
const llvm::ListInit *testsList =
59160
llvm::dyn_cast_or_null<llvm::ListInit>(testsVal->getValue());
60161
if (!testsList)
61162
return tests;
62-
163+
63164
for (const llvm::Init *init : testsList->getElements()) {
64165
const llvm::DefInit *defInit = llvm::dyn_cast<llvm::DefInit>(init);
65166
if (!defInit)
66167
continue;
67-
168+
68169
const llvm::Record *testRec = defInit->getDef();
69-
170+
70171
// Extract fields from LitTest record
71172
std::string name = testRec->getValueAsString("testFileName").str();
72173
std::string irSnippet = testRec->getValueAsString("irSnippet").str();
73-
174+
74175
llvm::SmallVector<std::string> runLines;
75176
llvm::for_each(*testRec->getValueAsListInit("runLines"), [&](const llvm::Init *init) {
76177
runLines.emplace_back(llvm::cast<llvm::StringInit>(init)->getValue());
@@ -83,31 +184,31 @@ static llvm::SmallVector<LitTest> extractTestsFromRecord(const llvm::Record *rec
83184

84185
tests.push_back(LitTest {
85186
record->getName().str(),
86-
name,
87-
irSnippet,
88-
runLines,
89-
checkLines,
187+
name,
188+
irSnippet,
189+
runLines,
190+
checkLines,
90191
});
91192
}
92-
193+
93194
return tests;
94195
}
95196

96-
/// Extract tests from passes
97-
static llvm::SmallVector<LitTest> extractPassTests(const RecordKeeper &records) {
197+
/// Extract tests from ops
198+
static llvm::SmallVector<LitTest> extractOpTests(const RecordKeeper &records) {
98199
llvm::SmallVector<LitTest> tests;
99-
100-
// Check if PassBase class exists before trying to get derived definitions
101-
if (records.getClass("PassBase")) {
102-
for (const llvm::Record *def : records.getAllDerivedDefinitions("PassBase")) {
200+
201+
// Check if Op class exists before trying to get derived definitions
202+
if (records.getClass("Op")) {
203+
for (const llvm::Record *def : records.getAllDerivedDefinitions("Op")) {
103204
if (def->isAnonymous())
104205
continue;
105-
106-
auto passTests = extractTestsFromRecord(def, "passes");
107-
tests.insert(tests.end(), passTests.begin(), passTests.end());
206+
207+
auto opTests = extractTestsFromRecord(def, "ops");
208+
tests.insert(tests.end(), opTests.begin(), opTests.end());
108209
}
109210
}
110-
211+
111212
return tests;
112213
}
113214

@@ -132,26 +233,25 @@ static void generateTestFile(const LitTest &test, llvm::raw_ostream &os) {
132233
/// Main function to generate all IR test test files
133234
static void generateLitTests(const RecordKeeper &records, raw_ostream &os) {
134235
llvm::SmallVector<LitTest> allTests;
135-
136-
// Extract tests from different definition types (only passes for now)
137-
auto passTests = extractPassTests(records);
138-
139-
allTests.insert(allTests.end(), passTests.begin(), passTests.end());
140-
236+
237+
// Extract tests from different definition types
238+
auto opTests = extractOpTests(records);
239+
allTests.insert(allTests.end(), opTests.begin(), opTests.end());
240+
141241
if (allTests.empty()) {
142-
os << "// No LitTest record found in any TableGen definition\n";
242+
os << "// No mlir_example code blocks found in any TableGen definition\n";
143243
return;
144244
}
145-
245+
146246
// Generate summary
147247
os << "// Generated " << allTests.size() << " LIT test files\n";
148248
os << "// Use the following files for LIT testing:\n\n";
149-
249+
150250
// Generate file list and content for each test
151251
for (const auto& test : allTests) {
152252
std::string testFileName = formatv("generated_{0}_{1}", test.sourceDefName, test.testFileName);
153253
os << "// File: " << testFileName << "\n";
154-
254+
155255
os << "// --- BEGIN " << testFileName << " ---\n";
156256
generateTestFile(test, os);
157257
os << "// --- END " << testFileName << " ---\n\n";

0 commit comments

Comments
 (0)