diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake index 6589458ab7894..9b05b70231dba 100644 --- a/mlir/cmake/modules/AddMLIR.cmake +++ b/mlir/cmake/modules/AddMLIR.cmake @@ -762,3 +762,103 @@ function(mlir_target_link_libraries target type) target_link_libraries(${target} ${type} ${ARGN}) endif() endfunction() + +# Extracts LIT tests embedded in `Testable` records in `tblgen_file` +# and generates a file per test in `output_dir` +# +# Example usage: +# # Extract tests from MyPasses.td and generate them in test/Passes/ +# add_embedded_lit_tests(MyPassesEmbeddedTests +# ${CMAKE_CURRENT_SOURCE_DIR}/include/MyPasses.td +# ${CMAKE_CURRENT_SOURCE_DIR}/test/Passes/) +# +# # This will: +# # 1. Process MyPasses.td with mlir-tblgen --gen-lit-tests +# # 2. Extract individual test files to test/Passes/ +# # 3. Generate files like: test/Passes/generated_MyPass_test1.mlir +# +function(add_embedded_lit_tests target tblgen_file output_dir) + set(LLVM_TARGET_DEFINITIONS ${tblgen_file}) + + # Extraction script content + set(EXTRACT_SCRIPT_CONTENT [[ + # Generated extraction script + if(NOT CONSOLIDATED_FILE) + message(FATAL_ERROR "CONSOLIDATED_FILE variable is required") + endif() + + if(NOT OUTPUT_DIR) + message(FATAL_ERROR "OUTPUT_DIR variable is required") + endif() + + if(NOT EXISTS ${CONSOLIDATED_FILE}) + message(FATAL_ERROR "Consolidated file does not exist: ${CONSOLIDATED_FILE}") + endif() + + # Read the consolidated file + file(READ ${CONSOLIDATED_FILE} file_content) + + # Split into lines for processing + string(REPLACE "\n" ";" lines "${file_content}") + + set(current_filename "") + set(current_content "") + set(in_test_block FALSE) + set(extracted_test_files) + + foreach(line IN LISTS lines) + # Check for filename line + if(line MATCHES "^// File: (.+)$") + set(current_filename "${CMAKE_MATCH_1}") + endif() + + # Check for BEGIN marker + if(line MATCHES "^// --- BEGIN .+ ---$") + set(in_test_block TRUE) + set(current_content "") + # Check for END marker + elseif(line MATCHES "^// --- END .+ ---$") + set(in_test_block FALSE) + + # Write the extracted content to file + if(current_filename AND current_content) + file(MAKE_DIRECTORY ${OUTPUT_DIR}) + file(WRITE ${OUTPUT_DIR}/${current_filename} "${current_content}") + message(STATUS "Extracted test file: ${current_filename}") + list(APPEND extracted_test_files ${current_filename}) + endif() + + set(current_filename "") + set(current_content "") + # Collect content within BEGIN/END block + elseif(in_test_block) + string(APPEND current_content "${line}\n") + endif() + endforeach() + + list(LENGTH extracted_test_files num_extracted_files) + message(STATUS "Extracted ${num_extracted_files} test files to ${OUTPUT_DIR}") + ]]) + + # Write extraction script to a file in the build directory + file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/extract_lit_tests.cmake "${EXTRACT_SCRIPT_CONTENT}") + + # Process tblgen_file and generate a file with all embedded LIT + # tests in tblgen_file + get_filename_component(tblgen_name ${tblgen_file} NAME_WE) + set(consolidated_output_file ${tblgen_name}_extracted_lit_tests.txt) + mlir_tablegen(${consolidated_output_file} --gen-lit-tests) + + # Add public tablegen target to trigger builds on changes in tblgen_file + add_public_tablegen_target(${target}) + + # Call the extraction script to extract all LIT tests into individual + # `.mlir` test files + add_custom_command(TARGET ${target} POST_BUILD + COMMAND ${CMAKE_COMMAND} + -DCONSOLIDATED_FILE=${CMAKE_CURRENT_BINARY_DIR}/${consolidated_output_file} + -DOUTPUT_DIR=${output_dir} + -P ${CMAKE_CURRENT_BINARY_DIR}/extract_lit_tests.cmake + COMMENT "Extracting LIT tests to individual files" + ) +endfunction() \ No newline at end of file diff --git a/mlir/examples/toy/Ch2/CMakeLists.txt b/mlir/examples/toy/Ch2/CMakeLists.txt index 3fbff2fa2a679..7eb4c22adc296 100644 --- a/mlir/examples/toy/Ch2/CMakeLists.txt +++ b/mlir/examples/toy/Ch2/CMakeLists.txt @@ -1,6 +1,11 @@ # For a better template to copy, see examples/standalone add_subdirectory(include) +add_embedded_lit_tests(EmbeddedLitTestsGen + "include/toy/Ops.td" + "${CMAKE_CURRENT_BINARY_DIR}" +) + set(LLVM_LINK_COMPONENTS Support ) @@ -13,8 +18,9 @@ add_toy_chapter(toyc-ch2 DEPENDS ToyCh2OpsIncGen - + EmbeddedLitTestsGen ) + include_directories(include/) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) target_link_libraries(toyc-ch2 diff --git a/mlir/include/mlir/Pass/PassBase.td b/mlir/include/mlir/Pass/PassBase.td index e37f9735e2241..d47780b735eea 100644 --- a/mlir/include/mlir/Pass/PassBase.td +++ b/mlir/include/mlir/Pass/PassBase.td @@ -14,6 +14,8 @@ #ifndef MLIR_PASS_PASSBASE #define MLIR_PASS_PASSBASE +include "mlir/IR/Testable.td" + //===----------------------------------------------------------------------===// // Options //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/gen-lit-tests.td b/mlir/test/mlir-tblgen/gen-lit-tests.td new file mode 100644 index 0000000000000..a377245d51450 --- /dev/null +++ b/mlir/test/mlir-tblgen/gen-lit-tests.td @@ -0,0 +1,74 @@ +// RUN: mlir-tblgen -gen-lit-tests -I %S/../../include %s | FileCheck %s + +include "mlir/IR/OpBase.td" + +def Test_Dialect : Dialect { + let name = "test"; + let cppNamespace = "test"; +} + +def TestOp : Op { + let summary = "test op with mlir_example code blocks"; + let description = [{ + This operation demonstrates the mlir_example feature for ops. + + Basic usage: + ```mlir_example(mlir-opt) + func.func @foo(%arg0: i32) -> i32 { + %0 = test.test_op %arg0 : i32 + return %0 : i32 + } + ``` + + And some more back to back examples - + + ```mlir_example(some-other-tool) + func.func @foo1(%arg1: i32) -> i32 { + %0 = test.test_op %arg1 : i32 + return %0 : i32 + } + ``` + ```mlir_example(yet-another-tool) + func.func @foo2(%arg2: i32) -> i32 { + %0 = test.test_op %arg2 : i32 + return %0 : i32 + } + ``` + }]; + + let arguments = (ins I32:$input); + let results = (outs I32:$output); +} + +// CHECK-LABEL: // Generated 3 LIT test files +// CHECK: // Use the following files for LIT testing: + +// CHECK: // File: generated_TestOp_example_0.mlir +// CHECK: // --- BEGIN generated_TestOp_example_0.mlir --- +// CHECK: mlir-opt %s --verify-roundtrip +// CHECK: // Generated from TableGen definition: TestOp +// CHECK: func.func @foo(%arg0: i32) -> i32 { +// CHECK: %0 = test.test_op %arg0 : i32 +// CHECK: return %0 : i32 +// CHECK: } +// CHECK: // --- END generated_TestOp_example_0.mlir --- + +// CHECK: // File: generated_TestOp_example_1.mlir +// CHECK: // --- BEGIN generated_TestOp_example_1.mlir --- +// CHECK: some-other-tool %s --verify-roundtrip +// CHECK: // Generated from TableGen definition: TestOp +// CHECK: func.func @foo1(%arg1: i32) -> i32 { +// CHECK: %0 = test.test_op %arg1 : i32 +// CHECK: return %0 : i32 +// CHECK: } +// CHECK: // --- END generated_TestOp_example_1.mlir --- + +// CHECK: // File: generated_TestOp_example_2.mlir +// CHECK: // --- BEGIN generated_TestOp_example_2.mlir --- +// CHECK: yet-another-tool %s --verify-roundtrip +// CHECK: // Generated from TableGen definition: TestOp +// CHECK: func.func @foo2(%arg2: i32) -> i32 { +// CHECK: %0 = test.test_op %arg2 : i32 +// CHECK: return %0 : i32 +// CHECK: } +// CHECK: // --- END generated_TestOp_example_2.mlir --- \ No newline at end of file diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt index 2a7ef7e0576c8..e721f1e26a2bd 100644 --- a/mlir/tools/mlir-tblgen/CMakeLists.txt +++ b/mlir/tools/mlir-tblgen/CMakeLists.txt @@ -16,6 +16,7 @@ add_tablegen(mlir-tblgen MLIR EnumsGen.cpp EnumPythonBindingGen.cpp FormatGen.cpp + LitTestGen.cpp LLVMIRConversionGen.cpp LLVMIRIntrinsicGen.cpp mlir-tblgen.cpp diff --git a/mlir/tools/mlir-tblgen/LitTestGen.cpp b/mlir/tools/mlir-tblgen/LitTestGen.cpp new file mode 100644 index 0000000000000..99bd86a20cb29 --- /dev/null +++ b/mlir/tools/mlir-tblgen/LitTestGen.cpp @@ -0,0 +1,197 @@ +//===- LitTestGen.cpp - LIT test generator ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// LitTestGen extracts `LitTest` records from `Testable` TableGen records and +// generates corresponding LIT test files. +// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/Operator.h" +#include "mlir/TableGen/Pass.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/Regex.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; +using llvm::formatv; +using llvm::RecordKeeper; + +static llvm::cl::OptionCategory litTestGenCategory("Options for -gen-lit-tests"); +static llvm::cl::opt + outputDir("output-dir", + llvm::cl::desc("Output directory for generated test files"), + llvm::cl::cat(litTestGenCategory), + llvm::cl::value_desc("directory")); + + +/// Cpp type corresponding to the `LitTest` record type in TableGen +struct LitTest { + std::string sourceDefName; + std::string testFileName; + std::string irSnippet; + llvm::SmallVector runLines; + llvm::SmallVector checkLines; +}; + +/// Extract code snippets with mlir_example tag from a description field. +/// Returns a vector of LitTest objects found within ```mlir_example ... ``` blocks. +static llvm::SmallVector extractOpTests(llvm::StringRef description, llvm::StringRef sourceDefName) { + llvm::SmallVector tests; + + // Pattern to match ```mlir_example ... ``` code blocks + // - ``` - Three literal backticks + // - `mlir_example` - Literal text + // - `(\(.+\))?` - Capture group matching the optional RUN tool name. Default is `mlir-opt`. + // - `([^`|`^`|``^`]+)` - Capture group matching the actual mlir IR example content (everything except for three consecutive backticks). + // - ``` - Three literal closing backticks + llvm::Regex codeBlockRegex("```mlir_example(\\([[:alnum:]_-]+\\))?[[:space:]]([^`|`^`|``^`]+)```"); + + auto remaining = description; + llvm::SmallVector matches; + + while (codeBlockRegex.match(remaining, &matches)) { + if (matches.size() == 3) { + std::string tool = "mlir-opt"; + // matches[1] contains the RUN tool name + if (!matches[1].empty()) { + tool = matches[1].ltrim('(').rtrim(')').str(); + } + + // matches[2] contains the code content + auto codeRef = matches[2]; + // Remove leading/trailing whitespace and comment markers (# prefix) + llvm::SmallVector lines; + codeRef.split(lines, '\n', -1, false); + + std::string processedCode; + for (llvm::StringRef line : lines) { + line = line.ltrim(); + // Remove leading # comment markers if present + if (line.starts_with("#")) { + line = line.drop_front(1).ltrim(); + } + if (!line.empty() || !processedCode.empty()) { + processedCode += line.str() + "\n"; + } + } + + if (!processedCode.empty()) { + // Generate test file name based on index + auto testFileName = formatv("example_{0}.mlir", tests.size()); + // Generate default RUN line with --verify-roundtrip + auto runLine = llvm::formatv("// RUN: {0} %s --verify-roundtrip", tool).str(); + + tests.push_back(LitTest{ + sourceDefName.str(), + testFileName, + processedCode, + {runLine}, + {} // No CHECK lines by default + }); + } + } + + // Move past this match to find the next one + size_t matchEnd = remaining.find("```", remaining.find("```mlir_example") + 15); + if (matchEnd == llvm::StringRef::npos) + break; + remaining = remaining.substr(matchEnd + 3); + } + + return tests; +} + +static llvm::SmallVector extractTestsFromRecord(const llvm::Record *record) { + llvm::SmallVector tests; + + // Try to extract mlir_example code blocks from the description field + const llvm::RecordVal *descVal = record->getValue("description"); + if (!descVal) + return tests; + + auto description = record->getValueAsString("description"); + if (description.empty()) + return tests; + + if (record->isSubClassOf("Op")) { + tests = extractOpTests(description, record->getName()); + } + + return tests; +} + +/// Generate a LIT test file for an IR test +static void generateTestFile(const LitTest &test, llvm::raw_ostream &os) { + // Add RUN lines + for (const auto& runLine : test.runLines) { + os << "\n" << runLine << "\n"; + } + + os << "// Generated from TableGen definition: " << test.sourceDefName << "\n\n"; + + // Add the test body + os << test.irSnippet << "\n"; + + // Add CHECK lines + for (const auto& checkLine : test.checkLines) { + os << "\n" << checkLine << "\n"; + } +} + +/// Main function to generate all IR test test files +static void generateLitTests(const RecordKeeper &records, raw_ostream &os) { + llvm::SmallVector allTests; + + // Extract tests from different definition types + if (records.getClass("Op")) { + for (const llvm::Record *def : records.getAllDerivedDefinitions("Op")) { + if (def->isAnonymous()) + continue; + + auto opTests = extractTestsFromRecord(def); + allTests.insert(allTests.end(), opTests.begin(), opTests.end()); + } + } + + if (allTests.empty()) { + os << "// No mlir_example code blocks found in any TableGen definition\n"; + return; + } + + // Generate summary + os << "// Generated " << allTests.size() << " LIT test files\n"; + os << "// Use the following files for LIT testing:\n\n"; + + // Generate file list and content for each test + for (const auto& test : allTests) { + std::string testFileName = formatv("generated_{0}_{1}", test.sourceDefName, test.testFileName); + os << "// File: " << testFileName << "\n"; + + os << "// --- BEGIN " << testFileName << " ---\n"; + generateTestFile(test, os); + os << "// --- END " << testFileName << " ---\n\n"; + } +} + +//===----------------------------------------------------------------------===// +// Generator Registration +//===----------------------------------------------------------------------===// + +static mlir::GenRegistration + genLitTests("gen-lit-tests", "Generate LIT test files for `Testable` TableGen records", + [](const RecordKeeper &records, raw_ostream &os) { + generateLitTests(records, os); + return false; + }); \ No newline at end of file