1919#include " llvm/Support/CommandLine.h"
2020#include " llvm/Support/FormatVariadic.h"
2121#include " llvm/Support/Path.h"
22+ #include " llvm/Support/Regex.h"
2223#include " llvm/TableGen/Error.h"
2324#include " llvm/TableGen/Record.h"
2425
@@ -41,36 +42,136 @@ static llvm::cl::opt<std::string>
4142struct LitTest {
4243 std::string sourceDefName;
4344 std::string testFileName;
44- std::string irSnippet;
45+ std::string irSnippet;
4546 llvm::SmallVector<std::string> runLines;
4647 llvm::SmallVector<std::string> checkLines;
4748};
4849
50+ // / Extract code snippets with mlir_example tag from a description field.
51+ // / Returns a vector of code snippets found within ```mlir_example ... ``` blocks.
52+ static llvm::SmallVector<std::string> extractMlirExamples (llvm::StringRef description) {
53+ llvm::SmallVector<std::string> examples;
54+
55+ // Pattern to match ```mlir_example ... ``` code blocks
56+ // [^\n]* matches rest of line after mlir_example
57+ // \n matches the newline after the opening fence
58+ // (.+?) captures the code content (non-greedy)
59+ // ``` matches the closing fence
60+ llvm::Regex codeBlockRegex (" ```mlir_example(.+)```" );
61+
62+ llvm::StringRef remaining = description;
63+ llvm::SmallVector<llvm::StringRef> matches;
64+
65+ while (codeBlockRegex.match (remaining, &matches)) {
66+ if (matches.size () >= 2 ) {
67+ // matches[1] contains the captured group (the code content)
68+ std::string code = matches[1 ].str ();
69+
70+ llvm::errs () << " DEBUG: Extracted raw code:\n [" << code << " ]\n " ;
71+
72+ // Remove leading/trailing whitespace and comment markers (# prefix)
73+ llvm::SmallVector<llvm::StringRef> lines;
74+ llvm::StringRef codeRef (code);
75+ codeRef.split (lines, ' \n ' , -1 , false );
76+
77+ std::string processedCode;
78+ for (llvm::StringRef line : lines) {
79+ line = line.ltrim ();
80+ // Remove leading # comment markers if present
81+ if (line.starts_with (" #" )) {
82+ line = line.drop_front (1 ).ltrim ();
83+ }
84+ if (!line.empty () || !processedCode.empty ()) {
85+ processedCode += line.str () + " \n " ;
86+ }
87+ }
88+
89+ // // Remove trailing empty lines
90+ // while (!processedCode.empty() && processedCode.back() == '\n') {
91+ // size_t lastNewline = processedCode.find_last_not_of('\n');
92+ // if (lastNewline == std::string::npos) {
93+ // processedCode.clear();
94+ // break;
95+ // }
96+ // processedCode = processedCode.substr(0, lastNewline + 1) + "\n";
97+ // }
98+
99+ if (!processedCode.empty ()) {
100+ examples.push_back (processedCode);
101+ }
102+ }
103+
104+ // Move past this match to find the next one
105+ size_t matchEnd = remaining.find (" ```" , remaining.find (" ```mlir_example" ) + 15 );
106+ if (matchEnd == llvm::StringRef::npos)
107+ break ;
108+ remaining = remaining.substr (matchEnd + 3 );
109+ }
110+
111+ return examples;
112+ }
113+
49114static llvm::SmallVector<LitTest> extractTestsFromRecord (const llvm::Record *record,
50115 llvm::StringRef dialectName = " " ) {
51116 llvm::SmallVector<LitTest> tests;
52-
53- // Check if the record has a tests field
117+
118+ // Try to extract mlir_example code blocks from the description field
119+ const llvm::RecordVal *descVal = record->getValue (" description" );
120+ if (descVal) {
121+ llvm::StringRef description = record->getValueAsString (" description" );
122+ llvm::errs () << " DEBUG: Record: " << record->getName () << " \n " ;
123+ llvm::errs () << " DEBUG: Description length: " << description.size () << " \n " ;
124+ llvm::errs () << " DEBUG: Description content:\n " << description << " \n " ;
125+ llvm::errs () << " DEBUG: ---\n " ;
126+ if (!description.empty ()) {
127+ llvm::SmallVector<std::string> examples = extractMlirExamples (description);
128+ llvm::errs () << " DEBUG: Found " << examples.size () << " examples\n " ;
129+
130+ // Create a LitTest for each extracted example
131+ for (size_t i = 0 ; i < examples.size (); ++i) {
132+ std::string testFileName;
133+ if (examples.size () == 1 ) {
134+ testFileName = " example.mlir" ;
135+ } else {
136+ testFileName = formatv (" example_{0}.mlir" , i);
137+ }
138+
139+ // Generate default RUN line with --verify-roundtrip
140+ llvm::SmallVector<std::string> runLines;
141+ runLines.push_back (" // RUN: mlir-opt %s --verify-roundtrip" );
142+
143+ tests.push_back (LitTest {
144+ record->getName ().str (),
145+ testFileName,
146+ examples[i],
147+ runLines,
148+ {} // No CHECK lines by default
149+ });
150+ }
151+ }
152+ }
153+
154+ // Fall back to checking for the old tests field for backward compatibility
54155 const llvm::RecordVal *testsVal = record->getValue (" tests" );
55156 if (!testsVal)
56157 return tests;
57-
58- const llvm::ListInit *testsList =
158+
159+ const llvm::ListInit *testsList =
59160 llvm::dyn_cast_or_null<llvm::ListInit>(testsVal->getValue ());
60161 if (!testsList)
61162 return tests;
62-
163+
63164 for (const llvm::Init *init : testsList->getElements ()) {
64165 const llvm::DefInit *defInit = llvm::dyn_cast<llvm::DefInit>(init);
65166 if (!defInit)
66167 continue ;
67-
168+
68169 const llvm::Record *testRec = defInit->getDef ();
69-
170+
70171 // Extract fields from LitTest record
71172 std::string name = testRec->getValueAsString (" testFileName" ).str ();
72173 std::string irSnippet = testRec->getValueAsString (" irSnippet" ).str ();
73-
174+
74175 llvm::SmallVector<std::string> runLines;
75176 llvm::for_each (*testRec->getValueAsListInit (" runLines" ), [&](const llvm::Init *init) {
76177 runLines.emplace_back (llvm::cast<llvm::StringInit>(init)->getValue ());
@@ -83,31 +184,31 @@ static llvm::SmallVector<LitTest> extractTestsFromRecord(const llvm::Record *rec
83184
84185 tests.push_back (LitTest {
85186 record->getName ().str (),
86- name,
87- irSnippet,
88- runLines,
89- checkLines,
187+ name,
188+ irSnippet,
189+ runLines,
190+ checkLines,
90191 });
91192 }
92-
193+
93194 return tests;
94195}
95196
96- // / Extract tests from passes
97- static llvm::SmallVector<LitTest> extractPassTests (const RecordKeeper &records) {
197+ // / Extract tests from ops
198+ static llvm::SmallVector<LitTest> extractOpTests (const RecordKeeper &records) {
98199 llvm::SmallVector<LitTest> tests;
99-
100- // Check if PassBase class exists before trying to get derived definitions
101- if (records.getClass (" PassBase " )) {
102- for (const llvm::Record *def : records.getAllDerivedDefinitions (" PassBase " )) {
200+
201+ // Check if Op class exists before trying to get derived definitions
202+ if (records.getClass (" Op " )) {
203+ for (const llvm::Record *def : records.getAllDerivedDefinitions (" Op " )) {
103204 if (def->isAnonymous ())
104205 continue ;
105-
106- auto passTests = extractTestsFromRecord (def, " passes " );
107- tests.insert (tests.end (), passTests .begin (), passTests .end ());
206+
207+ auto opTests = extractTestsFromRecord (def, " ops " );
208+ tests.insert (tests.end (), opTests .begin (), opTests .end ());
108209 }
109210 }
110-
211+
111212 return tests;
112213}
113214
@@ -132,26 +233,25 @@ static void generateTestFile(const LitTest &test, llvm::raw_ostream &os) {
132233// / Main function to generate all IR test test files
133234static void generateLitTests (const RecordKeeper &records, raw_ostream &os) {
134235 llvm::SmallVector<LitTest> allTests;
135-
136- // Extract tests from different definition types (only passes for now)
137- auto passTests = extractPassTests (records);
138-
139- allTests.insert (allTests.end (), passTests.begin (), passTests.end ());
140-
236+
237+ // Extract tests from different definition types
238+ auto opTests = extractOpTests (records);
239+ allTests.insert (allTests.end (), opTests.begin (), opTests.end ());
240+
141241 if (allTests.empty ()) {
142- os << " // No LitTest record found in any TableGen definition\n " ;
242+ os << " // No mlir_example code blocks found in any TableGen definition\n " ;
143243 return ;
144244 }
145-
245+
146246 // Generate summary
147247 os << " // Generated " << allTests.size () << " LIT test files\n " ;
148248 os << " // Use the following files for LIT testing:\n\n " ;
149-
249+
150250 // Generate file list and content for each test
151251 for (const auto & test : allTests) {
152252 std::string testFileName = formatv (" generated_{0}_{1}" , test.sourceDefName , test.testFileName );
153253 os << " // File: " << testFileName << " \n " ;
154-
254+
155255 os << " // --- BEGIN " << testFileName << " ---\n " ;
156256 generateTestFile (test, os);
157257 os << " // --- END " << testFileName << " ---\n\n " ;
0 commit comments