From ed33b750a0b51f9041b8cb6909fb04e44dd6613f Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 18 Jul 2025 16:39:56 -0400 Subject: [PATCH 01/63] Migrate to stablehlo --- .gitmodules | 9 +- mlir/CMakeLists.txt | 130 +- mlir/Makefile | 54 +- mlir/include/Catalyst/Transforms/Passes.td | 2 +- mlir/lib/Catalyst/Transforms/CMakeLists.txt | 6 + .../Transforms/HloCustomCallPatterns.cpp | 13 +- .../Catalyst/Transforms/ScatterPatterns.cpp | 27 +- .../Transforms/hlo_custom_call_lowering.cpp | 4 +- .../Catalyst/Transforms/scatter_lowering.cpp | 6 +- mlir/lib/Driver/CMakeLists.txt | 8 +- mlir/lib/Driver/CompilerDriver.cpp | 10 +- mlir/lib/Driver/Pipelines.cpp | 16 +- mlir/mlir-hlo | 1 - .../mhlo-add-back-necessary-passes.patch | 1317 ----------------- mlir/patches/mhlo-remove-shardy.patch | 132 -- mlir/stablehlo | 1 + mlir/tools/catalyst-cli/CMakeLists.txt | 6 +- mlir/tools/quantum-lsp-server/CMakeLists.txt | 2 +- .../quantum-lsp-server/quantum-lsp-server.cpp | 2 - mlir/tools/quantum-opt/CMakeLists.txt | 6 +- mlir/tools/quantum-opt/quantum-opt.cpp | 17 +- 21 files changed, 170 insertions(+), 1599 deletions(-) delete mode 160000 mlir/mlir-hlo delete mode 100644 mlir/patches/mhlo-add-back-necessary-passes.patch delete mode 100644 mlir/patches/mhlo-remove-shardy.patch create mode 160000 mlir/stablehlo diff --git a/.gitmodules b/.gitmodules index 0148d21bdd..6dc71ddc43 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,6 @@ -[submodule "mlir-hlo"] - path = mlir/mlir-hlo - url = https://github.com/tensorflow/mlir-hlo.git +[submodule "stablehlo"] + path = mlir/stablehlo + url = https://github.com/openxla/stablehlo.git shallow = true ignore = dirty [submodule "llvm-project"] @@ -13,3 +13,6 @@ url = https://github.com/EnzymeAD/Enzyme.git shallow = true ignore = dirty +[submodule "mlir/stablehlo"] + path = mlir/stablehlo + url = https://github.com/openxla/stablehlo.git diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index 5322d486e8..e305486d23 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -29,11 +29,16 @@ endif() ######################### find_package(MLIR REQUIRED CONFIG) -find_package(MHLO REQUIRED CONFIG) +#find_package(STABLEHLO REQUIRED CONFIG) +# add_subdirectory(llvm-project/mlir/cmake/modules) + + +message("hi, stable hlo src dir: ${STABLEHLO_SRC_DIR}") +#include_directories(PUBLIC ${STABLEHLO_SRC_DIR}) message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") -message(STATUS "Using MHLOConfig.cmake in: ${MHLO_DIR}") +#message(STATUS "Using STABLEHLOConfig.cmake in: ${STABLEHLO_DIR}") set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin) set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib) @@ -42,20 +47,21 @@ set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) # Taken from mlir-hlo/mhlo/transforms/CMakeLists.txt. # Unfortunately, AllMhloPasses doesn't appear to be exported. set(ALL_MHLO_PASSES - ChloPasses - MhloPasses - StablehloPasses - MhloToArithmeticConversion - MhloToMemrefConversion - HloToLinalgUtils - MhloToLinalg - MhloToStablehlo - StablehloToMhlo + # ChloPasses + # MhloPasses + # StablehloPasses + # MhloToArithmeticConversion + # MhloToMemrefConversion + # HloToLinalgUtils + # MhloToLinalg + # MhloToStablehlo + # StablehloToMhlo + # StablehloPasses ) list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") -list(APPEND CMAKE_MODULE_PATH "${MHLO_CMAKE_DIR}") +# list(APPEND CMAKE_MODULE_PATH "${STABLEHLO_CMAKE_DIR}") list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/modules") # Policy CMP0175 was introduced in CMake 4.31 and raises warnings in the upstream CMake modules. @@ -85,8 +91,8 @@ if(QUANTUM_ENABLE_BINDINGS_PYTHON) mlir_configure_python_dev_packages() endif() -list(GET MHLO_INCLUDE_DIRS 1 MLIRHLO_DIR) -list(GET MHLO_INCLUDE_DIRS 2 MLIRHLO_BUILD_DIR) +# list(GET STABLEHLO_INCLUDE_DIRS 1 MLIRHLO_DIR) +# list(GET STABLEHLO_INCLUDE_DIRS 2 MLIRHLO_BUILD_DIR) set(CATALYST_MAIN_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/include) set(CATALYST_GEN_INCLUDE_DIR ${PROJECT_BINARY_DIR}/include) @@ -95,9 +101,9 @@ set(CATALYST_LIB_DIR ${PROJECT_BINARY_DIR}) include_directories(SYSTEM ${LLVM_INCLUDE_DIRS} ${MLIR_INCLUDE_DIRS} - ${MHLO_INCLUDE_DIRS} - ${MLIRHLO_DIR}/stablehlo - ${MLIRHLO_BUILD_DIR}/stablehlo + #${STABLEHLO_INCLUDE_DIRS} + #${MLIRHLO_DIR}/stablehlo + #${MLIRHLO_BUILD_DIR}/stablehlo ) link_directories(${LLVM_BUILD_LIBRARY_DIR}) add_definitions(${LLVM_DEFINITIONS}) @@ -162,53 +168,55 @@ add_subdirectory(cmake/modules) # Handle unittests when building out-of-tree against an installed version of # LLVM/MLIR (not a build tree). Adapted from `llvm/flang/CMakeLists.txt`. -set(CATALYST_GTEST_AVAILABLE 0) -if (TARGET llvm_gtest) - # Installed gtest, via LLVM_INSTALL_GTEST. Preferred. - message(STATUS "LLVM GTest found, enabling unittests") - set(CATALYST_GTEST_AVAILABLE 1) -else() - find_package(Threads REQUIRED) - set(LLVM_THIRD_PARTY_DIR llvm-project/third-party) - set(UNITTEST_DIR ${LLVM_THIRD_PARTY_DIR}/unittest) - if (NOT EXISTS ${UNITTEST_DIR}/googletest/include/gtest/gtest.h) - set(UNITTEST_DIR ${CMAKE_CURRENT_SOURCE_DIR}/llvm/third-party/unittest) - endif() - if (EXISTS ${UNITTEST_DIR}/googletest/include/gtest/gtest.h) - add_llvm_library(llvm_gtest - ${UNITTEST_DIR}/googletest/src/gtest-all.cc - ${UNITTEST_DIR}/googlemock/src/gmock-all.cc - LINK_COMPONENTS Support # llvm::raw_ostream - BUILDTREE_ONLY - ) - target_include_directories(llvm_gtest SYSTEM - PUBLIC - "${UNITTEST_DIR}/googletest/include" - "${UNITTEST_DIR}/googlemock/include" - PRIVATE - "${UNITTEST_DIR}/googletest" - "${UNITTEST_DIR}/googlemock" - ) - target_link_libraries(llvm_gtest PUBLIC Threads::Threads) - add_llvm_library(llvm_gtest_main - ${UNITTEST_DIR}/UnitTestMain/TestMain.cpp - LINK_LIBS llvm_gtest - LINK_COMPONENTS Support # llvm::cl - BUILDTREE_ONLY - ) - set(CATALYST_GTEST_AVAILABLE 1) - else() - message(WARNING "Skipping unittests since LLVM install does not include \ - gtest headers and libraries") - set(CATALYST_GTEST_AVAILABLE 0) - endif() -endif() -if (CATALYST_GTEST_AVAILABLE) - add_subdirectory(unittests) -endif() +# set(CATALYST_GTEST_AVAILABLE 0) +# if (TARGET llvm_gtest) +# # Installed gtest, via LLVM_INSTALL_GTEST. Preferred. +# message(STATUS "LLVM GTest found, enabling unittests") +# set(CATALYST_GTEST_AVAILABLE 1) +# else() +# find_package(Threads REQUIRED) +# set(LLVM_THIRD_PARTY_DIR llvm-project/third-party) +# set(UNITTEST_DIR ${LLVM_THIRD_PARTY_DIR}/unittest) +# if (NOT EXISTS ${UNITTEST_DIR}/googletest/include/gtest/gtest.h) +# set(UNITTEST_DIR ${CMAKE_CURRENT_SOURCE_DIR}/llvm/third-party/unittest) +# endif() +# if (EXISTS ${UNITTEST_DIR}/googletest/include/gtest/gtest.h) +# add_llvm_library(llvm_gtest +# ${UNITTEST_DIR}/googletest/src/gtest-all.cc +# ${UNITTEST_DIR}/googlemock/src/gmock-all.cc +# LINK_COMPONENTS Support # llvm::raw_ostream +# BUILDTREE_ONLY +# ) +# target_include_directories(llvm_gtest SYSTEM +# PUBLIC +# "${UNITTEST_DIR}/googletest/include" +# "${UNITTEST_DIR}/googlemock/include" +# PRIVATE +# "${UNITTEST_DIR}/googletest" +# "${UNITTEST_DIR}/googlemock" +# ) +# target_link_libraries(llvm_gtest PUBLIC Threads::Threads) +# add_llvm_library(llvm_gtest_main +# ${UNITTEST_DIR}/UnitTestMain/TestMain.cpp +# LINK_LIBS llvm_gtest +# LINK_COMPONENTS Support # llvm::cl +# BUILDTREE_ONLY +# ) +# set(CATALYST_GTEST_AVAILABLE 1) +# else() +# message(WARNING "Skipping unittests since LLVM install does not include \ +# gtest headers and libraries") +# set(CATALYST_GTEST_AVAILABLE 0) +# endif() +# endif() +# if (CATALYST_GTEST_AVAILABLE) +# add_subdirectory(unittests) +# endif() ###################### # End of CIRCT code # ###################### add_subdirectory(test) +unset(LLVM_USE_LINKER) +add_subdirectory(stablehlo) diff --git a/mlir/Makefile b/mlir/Makefile index 1e7d4fd832..d482d0300f 100644 --- a/mlir/Makefile +++ b/mlir/Makefile @@ -7,7 +7,7 @@ MK_ABSPATH := $(abspath $(lastword $(MAKEFILE_LIST))) MK_DIR := $(dir $(MK_ABSPATH)) DIALECTS_BUILD_DIR ?= $(MK_DIR)/build LLVM_BUILD_DIR ?= $(MK_DIR)/llvm-project/build -MHLO_BUILD_DIR ?= $(MK_DIR)/mlir-hlo/bazel-build +STABLEHLO_BUILD_DIR ?= $(MK_DIR)/stablehlo/build ENZYME_BUILD_DIR ?= $(MK_DIR)/Enzyme/build RT_BUILD_DIR ?= $(MK_DIR)/../runtime/build ENABLE_ASAN ?= OFF @@ -43,7 +43,7 @@ help: @echo "Please use \`make ' where is one of" @echo " all to build MLIR, MLIR-HLO and custom Catalyst dialects" @echo " llvm to build MLIR enabling Python bindings" - @echo " mhlo to build MLIR-HLO" + @echo " stablehlo to build stablehlo" @echo " enzyme to build Enzyme" @echo " dialects to build custom Catalyst MLIR dialects" @echo " test to run the Catalyst MLIR dialects test suite" @@ -52,7 +52,7 @@ help: @echo " format [version=?] to apply C++ formatter; use with 'version={version}' to run clang-format-{version} instead of clang-format" .PHONY: all -all: llvm mhlo enzyme dialects plugin +all: llvm stablehlo enzyme dialects plugin .PHONY: llvm llvm: @@ -83,36 +83,29 @@ llvm: # test to reduce unnecessary dependencies. LIT_FILTER_OUT="Bytecode|tosa-to-tensor|execution_engine" cmake --build $(LLVM_BUILD_DIR) --target $(LLVM_TARGETS) -.PHONY: mhlo -mhlo: - @echo "build MLIR-HLO" - # Patch MHLO shardy dependency - @if cd mlir-hlo; git apply --check $(MK_DIR)/patches/mhlo-remove-shardy.patch; then \ - git apply $(MK_DIR)/patches/mhlo-remove-shardy.patch; \ - fi +.PHONY: stablehlo +stablehlo: + @echo "build stablehlo" - # Patch MHLO passes removed from upstream - @if cd mlir-hlo; git apply --check $(MK_DIR)/patches/mhlo-add-back-necessary-passes.patch; then \ - git apply $(MK_DIR)/patches/mhlo-add-back-necessary-passes.patch; \ - fi - cmake -G Ninja -S mlir-hlo -B $(MHLO_BUILD_DIR) \ + cmake -G Ninja -S stablehlo -B $(STABLEHLO_BUILD_DIR) \ + -DSTABLEHLO_ENABLE_LLD=ON \ -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \ - -DLLVM_ENABLE_ASSERTIONS=ON \ -DMLIR_DIR=$(LLVM_BUILD_DIR)/lib/cmake/mlir \ - -DPython3_EXECUTABLE=$(PYTHON) \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DSTABLEHLO_ENABLE_BINDINGS_PYTHON=OFF \ + -DSTABLEHLO_ENABLE_SPLIT_DWARF=ON \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ + -DSTABLEHLO_ENABLE_SANITIZER=address \ -DCMAKE_C_COMPILER=$(C_COMPILER) \ -DCMAKE_CXX_COMPILER=$(CXX_COMPILER) \ -DCMAKE_C_COMPILER_LAUNCHER=$(COMPILER_LAUNCHER) \ -DCMAKE_CXX_COMPILER_LAUNCHER=$(COMPILER_LAUNCHER) \ -DCMAKE_EXE_LINKER_FLAGS=$(USE_SANITIZER_FLAGS) \ - -DLLVM_ENABLE_LLD=$(ENABLE_LLD) \ - -DLLVM_ENABLE_ZLIB=$(ENABLE_ZLIB) \ - -DLLVM_ENABLE_ZSTD=$(ENABLE_ZSTD) \ -DCMAKE_CXX_VISIBILITY_PRESET=$(SYMBOL_VISIBILITY) - # TODO: figure out why this test is failing - LIT_FILTER_OUT="chlo_legalize_to_mhlo" cmake --build $(MHLO_BUILD_DIR) --target check-mlir-hlo + cmake --build $(STABLEHLO_BUILD_DIR) + #ninja check-stablehlo-tests .PHONY: enzyme enzyme: TARGET_FILE := $(MK_DIR)/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -172,8 +165,8 @@ dialects: -DEnzyme_DIR=$(ENZYME_BUILD_DIR) \ -DENZYME_SRC_DIR=$(MK_DIR)/Enzyme \ -DMLIR_DIR=$(LLVM_BUILD_DIR)/lib/cmake/mlir \ - -DMHLO_DIR=$(MHLO_BUILD_DIR)/lib/cmake/mlir-hlo \ - -DMHLO_BINARY_DIR=$(MHLO_BUILD_DIR)/bin \ + -DSTABLEHLO_SRC_DIR=$(MK_DIR)/stablehlo \ + -DSTABLEHLO_BINARY_DIR=$(STABLEHLO_BUILD_DIR)/bin \ -DRUNTIME_LIB_DIR=$(RT_BUILD_DIR)/lib \ -DMLIR_LIB_DIR=$(LLVM_BUILD_DIR)/lib \ -DCMAKE_C_COMPILER=$(C_COMPILER) \ @@ -186,15 +179,15 @@ dialects: -DLLVM_ENABLE_ZSTD=$(ENABLE_ZSTD) \ -DCATALYST_ENABLE_WARNINGS=$(STRICT_WARNINGS) - cmake --build $(DIALECTS_BUILD_DIR) --target check-dialects quantum-lsp-server catalyst-cli check-unit-tests + cmake --build $(DIALECTS_BUILD_DIR) --target check-dialects quantum-lsp-server catalyst-cli #check-unit-tests .PHONY: test test: @echo "test the Catalyst MLIR dialects test suite" cmake --build $(DIALECTS_BUILD_DIR) --target check-dialects -.PHONY: clean clean-dialects clean-enzyme clean-mhlo clean-plugin -clean: clean-dialects clean-llvm clean-mhlo clean-enzyme clean-plugin +.PHONY: clean clean-dialects clean-enzyme clean-stablehlo clean-plugin +clean: clean-dialects clean-llvm clean-stablehlo clean-enzyme clean-plugin clean-dialects: @echo "clean catalyst dialect build files" @@ -204,10 +197,9 @@ clean-llvm: @echo "clean llvm/mlir build files" rm -rf $(LLVM_BUILD_DIR) -clean-mhlo: - @echo "clean HLO dialect build files" - rm -rf $(MHLO_BUILD_DIR) - cd mlir-hlo; git clean -fd; git checkout . +clean-stablehlo: + @echo "clean stablehlo dialect build files" + rm -rf $(STABLEHLO_BUILD_DIR) clean-enzyme: @echo "clean enzyme build files" diff --git a/mlir/include/Catalyst/Transforms/Passes.td b/mlir/include/Catalyst/Transforms/Passes.td index e1512e00e0..e731cecddc 100644 --- a/mlir/include/Catalyst/Transforms/Passes.td +++ b/mlir/include/Catalyst/Transforms/Passes.td @@ -61,7 +61,7 @@ def ScatterLoweringPass : Pass<"scatter-lowering"> { let dependentDialects = [ "mlir::func::FuncDialect", "index::IndexDialect", - "mhlo::MhloDialect", + "stablehlo::StablehloDialect", "tensor::TensorDialect", "scf::SCFDialect" ]; diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index b4776af6a9..bdffe06b24 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -34,10 +34,16 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) set(LIBS ${dialect_libs} ${conversion_libs} + StablehloPasses + StablehloOps ) set(DEPENDS MLIRCatalystPassIncGen + StablehloBaseIncGen + #StablehloBaseIncGen + # StablehloPasses + # StablehloOps ) add_mlir_library(${LIBRARY_NAME} STATIC ${SRC} LINK_LIBS PRIVATE ${LIBS} DEPENDS ${DEPENDS}) diff --git a/mlir/lib/Catalyst/Transforms/HloCustomCallPatterns.cpp b/mlir/lib/Catalyst/Transforms/HloCustomCallPatterns.cpp index 019abc33ce..3b1fe91149 100644 --- a/mlir/lib/Catalyst/Transforms/HloCustomCallPatterns.cpp +++ b/mlir/lib/Catalyst/Transforms/HloCustomCallPatterns.cpp @@ -14,22 +14,23 @@ #define DEBUG_TYPE "scatter" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" + +#include "stablehlo/dialect/StablehloOps.h" #include "Catalyst/IR/CatalystOps.h" -#include "mhlo/IR/hlo_ops.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "llvm/Support/Debug.h" using namespace mlir; namespace catalyst { -struct HloCustomCallOpRewritePattern : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; +struct HloCustomCallOpRewritePattern : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - mlir::LogicalResult matchAndRewrite(mhlo::CustomCallOp op, + mlir::LogicalResult matchAndRewrite(stablehlo::CustomCallOp op, mlir::PatternRewriter &rewriter) const override { StringRef calleeName = op.getCallTargetName(); diff --git a/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp b/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp index 3a95533737..eeb462b399 100644 --- a/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp +++ b/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp @@ -23,16 +23,17 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mhlo/IR/hlo_ops.h" +#include "stablehlo/dialect/StablehloOps.h" + using namespace mlir; namespace catalyst { -struct ScatterOpRewritePattern : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; +struct ScatterOpRewritePattern : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - void emitIndicesError(mhlo::ScatterOp op) const + void emitIndicesError(stablehlo::ScatterOp op) const { op.emitError() << "Indices are not unique and/or not sorted. Note that when using multiple indices " @@ -44,7 +45,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern << ", sorted: " << op.getIndicesAreSorted(); } - mlir::LogicalResult onlyOneInputUpdateAndResult(mhlo::ScatterOp op) const + mlir::LogicalResult onlyOneInputUpdateAndResult(stablehlo::ScatterOp op) const { // Semantics of scatter: // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter @@ -64,7 +65,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern return op.getResults().size() == 1 ? success() : failure(); } - mlir::LogicalResult isAssignment(mhlo::ScatterOp op) const + mlir::LogicalResult isAssignment(stablehlo::ScatterOp op) const { // From: // C23: update_computation has type @@ -90,7 +91,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern return failure(); } - mhlo::ReturnOp returnOp = dyn_cast(block.getTerminator()); + stablehlo::ReturnOp returnOp = dyn_cast(block.getTerminator()); if (!returnOp) { return failure(); } @@ -98,7 +99,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern return returnOp.getResults().front() == block.getArgument(1) ? success() : failure(); } - mlir::LogicalResult noBatching(mhlo::ScatterOp op) const + mlir::LogicalResult noBatching(stablehlo::ScatterOp op) const { // Ok, now that we know it is an assignment, we need to worry about // where exactly are we assigning and what are we assigning. @@ -123,7 +124,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern // return scatterDimNumbers.getInputBatchingDims().empty() ? success() : failure(); } - mlir::LogicalResult singleFullSlices(mhlo::ScatterOp op) const + mlir::LogicalResult singleFullSlices(stablehlo::ScatterOp op) const { // From: // More formally, for all update_index in index_space(updates[0]): @@ -144,13 +145,13 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern return rank == scatterDimNumbers.getUpdateWindowDims().size() ? success() : failure(); } - mlir::LogicalResult canBeDoneWithSingleTensorInsertSlice(mhlo::ScatterOp op) const + mlir::LogicalResult canBeDoneWithSingleTensorInsertSlice(stablehlo::ScatterOp op) const { return cast(op.getScatterIndices().getType()).getRank() == 1 ? success() : failure(); } - mlir::LogicalResult lowerToTensorInsertSlice(mhlo::ScatterOp op, + mlir::LogicalResult lowerToTensorInsertSlice(stablehlo::ScatterOp op, mlir::PatternRewriter &rewriter) const { // mhlo::ScatterOp is exactly the same as stablehlo::ScatterOp @@ -284,7 +285,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern return success(); } - mlir::LogicalResult matchAndRewrite(mhlo::ScatterOp op, + mlir::LogicalResult matchAndRewrite(stablehlo::ScatterOp op, mlir::PatternRewriter &rewriter) const override { // FastPath @@ -457,7 +458,7 @@ struct ScatterOpRewritePattern : public mlir::OpRewritePattern }; // Store all the necessary variables for the SCF for op in above defined struct - UpdateData getUpdateData(mhlo::ScatterOp &op, mlir::PatternRewriter &rewriter, + UpdateData getUpdateData(stablehlo::ScatterOp &op, mlir::PatternRewriter &rewriter, mlir::Location loc) const { UpdateData data; diff --git a/mlir/lib/Catalyst/Transforms/hlo_custom_call_lowering.cpp b/mlir/lib/Catalyst/Transforms/hlo_custom_call_lowering.cpp index 290d3827a6..b2bea6cb65 100644 --- a/mlir/lib/Catalyst/Transforms/hlo_custom_call_lowering.cpp +++ b/mlir/lib/Catalyst/Transforms/hlo_custom_call_lowering.cpp @@ -18,8 +18,8 @@ #include "llvm/Support/Debug.h" -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/passes.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" diff --git a/mlir/lib/Catalyst/Transforms/scatter_lowering.cpp b/mlir/lib/Catalyst/Transforms/scatter_lowering.cpp index e56b35390a..2a95de11c9 100644 --- a/mlir/lib/Catalyst/Transforms/scatter_lowering.cpp +++ b/mlir/lib/Catalyst/Transforms/scatter_lowering.cpp @@ -18,8 +18,10 @@ #include "llvm/Support/Debug.h" -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/passes.h" +// #include "mhlo/IR/hlo_ops.h" +// #include "mhlo/transforms/passes.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" diff --git a/mlir/lib/Driver/CMakeLists.txt b/mlir/lib/Driver/CMakeLists.txt index 1bb5720366..746ddfdcb7 100644 --- a/mlir/lib/Driver/CMakeLists.txt +++ b/mlir/lib/Driver/CMakeLists.txt @@ -40,10 +40,10 @@ set(LIBS mitigation-transforms MLIRIon ion-transforms - MhloRegisterDialects + #MhloRegisterDialects StablehloRegister MLIRCatalystTest - ${ALL_MHLO_PASSES} + #${ALL_MHLO_PASSES} ${ENZYME_LIB} ) @@ -52,6 +52,10 @@ add_mlir_library(CatalystCompilerDriver CatalystLLVMTarget.cpp Pipelines.cpp + DEPENDS + StablehloBaseIncGen + LINK_LIBS PRIVATE ${LIBS} + #StablehloBase ) diff --git a/mlir/lib/Driver/CompilerDriver.cpp b/mlir/lib/Driver/CompilerDriver.cpp index d72ef39ebb..beedf0695e 100644 --- a/mlir/lib/Driver/CompilerDriver.cpp +++ b/mlir/lib/Driver/CompilerDriver.cpp @@ -24,8 +24,10 @@ #include #include -#include "mhlo/IR/register.h" -#include "mhlo/transforms/passes.h" +#include "stablehlo/transforms/Passes.h" +#include "stablehlo/dialect/Register.h" +#include "stablehlo/integrations/c/StablehloPasses.h" + #include "mlir/IR/DialectRegistry.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllExtensions.h" @@ -34,7 +36,6 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Target/LLVMIR/Export.h" -#include "stablehlo/dialect/Register.h" #include "llvm/Analysis/CGSCCPassManager.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/IR/LegacyPassManager.h" @@ -294,7 +295,6 @@ void registerAllCatalystDialects(DialectRegistry ®istry) registerAllExtensions(registry); // HLO - mhlo::registerAllMhloDialects(registry); stablehlo::registerAllDialects(registry); // Catalyst @@ -962,7 +962,7 @@ int QuantumDriverMainFromCL(int argc, char **argv) registerAllPasses(); registerAllCatalystPasses(); registerAllCatalystPipelines(); - mhlo::registerAllMhloPasses(); + mlirRegisterAllStablehloPasses(); registerAllCatalystDialects(registry); registerLLVMTranslations(registry); diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index 46f6905a66..35f74a1e1e 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -20,7 +20,7 @@ #include "Mitigation/Transforms/Passes.h" #include "Quantum/IR/QuantumDialect.h" #include "Quantum/Transforms/Passes.h" -#include "mhlo/transforms/passes.h" +#include "stablehlo/transforms/Passes.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" @@ -40,13 +40,13 @@ void createEnforceRuntimeInvariantsPipeline(OpPassManager &pm) void createHloLoweringPipeline(OpPassManager &pm) { pm.addPass(mlir::createCanonicalizerPass()); - pm.addNestedPass(mhlo::createChloLegalizeToHloPass()); - pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); - pm.addNestedPass(mhlo::createLegalizeControlFlowPass()); - pm.addNestedPass(mhlo::createLegalizeHloToLinalgPass()); - pm.addNestedPass(mhlo::createLegalizeToStdPass()); - pm.addNestedPass(mhlo::createLegalizeSortPass()); - pm.addPass(mlir::mhlo::createConvertToSignlessPass()); + //pm.addNestedPass(stablehlo::createChloLegalizeToStablehloPass()); + //pm.addPass(stablehlo::createStablehloLegalizeToHloPass()); + //pm.addNestedPass(stablehlo::createLegalizeControlFlowPass()); + //pm.addNestedPass(stablehlo::createLegalizeHloToLinalgPass()); + //pm.addNestedPass(stablehlo::createLegalizeToStdPass()); + //pm.addNestedPass(stablehlo::createLegalizeSortPass()); + //pm.addPass(stablehlo::createConvertToSignlessPass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(catalyst::createScatterLoweringPass()); pm.addPass(catalyst::createHloCustomCallLoweringPass()); diff --git a/mlir/mlir-hlo b/mlir/mlir-hlo deleted file mode 160000 index 617a9361d1..0000000000 --- a/mlir/mlir-hlo +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 617a9361d186199480c080c9e8c474a5e30c22d1 diff --git a/mlir/patches/mhlo-add-back-necessary-passes.patch b/mlir/patches/mhlo-add-back-necessary-passes.patch deleted file mode 100644 index b56ede8dd5..0000000000 --- a/mlir/patches/mhlo-add-back-necessary-passes.patch +++ /dev/null @@ -1,1317 +0,0 @@ -From b1728b65b1511cd5ef3e11650b9e416d3fad068f Mon Sep 17 00:00:00 2001 -From: paul0403 -Date: Thu, 29 May 2025 11:00:56 -0400 -Subject: [PATCH] restore the removed mhlo passes we need: - mhlo-legalize-control-flow, mhlo-legalize-to-std, hlo-legalize-sort - ---- - mhlo/transforms/CMakeLists.txt | 6 + - .../legalize_control_flow.cc | 288 +++++++++ - .../transforms/legalize_sort/legalize_sort.cc | 577 ++++++++++++++++++ - .../legalize_to_standard.cc | 243 ++++++++ - .../legalize_to_standard_patterns.td | 92 +++ - mhlo/transforms/mhlo_passes.td | 19 + - mhlo/transforms/passes.h | 4 + - 7 files changed, 1229 insertions(+) - create mode 100644 mhlo/transforms/legalize_control_flow/legalize_control_flow.cc - create mode 100644 mhlo/transforms/legalize_sort/legalize_sort.cc - create mode 100644 mhlo/transforms/legalize_to_standard/legalize_to_standard.cc - create mode 100644 mhlo/transforms/legalize_to_standard/legalize_to_standard_patterns.td - -diff --git a/mhlo/transforms/CMakeLists.txt b/mhlo/transforms/CMakeLists.txt -index d6848633..26d3b419 100644 ---- a/mhlo/transforms/CMakeLists.txt -+++ b/mhlo/transforms/CMakeLists.txt -@@ -26,14 +26,20 @@ set(LLVM_TARGET_DEFINITIONS chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.t - mlir_tablegen(chlo_legalize_to_hlo/generated_chlo_legalize_to_hlo.inc -gen-rewriters) - add_public_tablegen_target(MLIRChloLegalizeToHloIncGen) - -+set(LLVM_TARGET_DEFINITIONS legalize_to_standard/legalize_to_standard_patterns.td) -+mlir_tablegen(legalize_to_standard/generated_legalize_to_standard.inc -gen-rewriters) -+add_public_tablegen_target(MLIRMhloLegalizeToStandardIncGen) - - - add_mlir_library(MhloPasses - collapse_elementwise_map/collapse_elementwise_map.cc - convert_to_signless/convert_to_signless_pass.cc - expand_hlo_tuples/expand_hlo_tuples.cc -+ legalize_control_flow/legalize_control_flow.cc - legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc - legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc -+ legalize_sort/legalize_sort.cc -+ legalize_to_standard/legalize_to_standard.cc - legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc - legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc - materialize_broadcasts/materialize_broadcasts.cc -diff --git a/mhlo/transforms/legalize_control_flow/legalize_control_flow.cc b/mhlo/transforms/legalize_control_flow/legalize_control_flow.cc -new file mode 100644 -index 00000000..9d473b9a ---- /dev/null -+++ b/mhlo/transforms/legalize_control_flow/legalize_control_flow.cc -@@ -0,0 +1,288 @@ -+/* Copyright 2019 The OpenXLA Authors. -+ -+Licensed under the Apache License, Version 2.0 (the "License"); -+you may not use this file except in compliance with the License. -+You may obtain a copy of the License at -+ -+ http://www.apache.org/licenses/LICENSE-2.0 -+ -+Unless required by applicable law or agreed to in writing, software -+distributed under the License is distributed on an "AS IS" BASIS, -+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+See the License for the specific language governing permissions and -+limitations under the License. -+==============================================================================*/ -+ -+// This file implements logic for lowering MHLO dialect to SCF dialect. -+#include -+#include -+#include -+ -+#include "llvm/Support/Casting.h" -+#include "mhlo/IR/hlo_ops.h" -+#include "mhlo/transforms/passes.h" -+#include "mlir/Dialect/Func/IR/FuncOps.h" -+#include "mlir/Dialect/SCF/IR/SCF.h" -+#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project -+#include "mlir/IR/Block.h" -+#include "mlir/IR/Builders.h" -+#include "mlir/IR/BuiltinTypes.h" -+#include "mlir/IR/Diagnostics.h" -+#include "mlir/IR/PatternMatch.h" -+#include "mlir/IR/TypeUtilities.h" -+#include "mlir/Pass/Pass.h" -+#include "mlir/Support/LLVM.h" -+#include "mlir/Support/LogicalResult.h" -+#include "mlir/Transforms/DialectConversion.h" -+ -+namespace mlir { -+namespace mhlo { -+ -+#define GEN_PASS_DEF_LEGALIZECONTROLFLOWPASS -+#include "mhlo/transforms/mhlo_passes.h.inc" -+ -+namespace { -+ -+// All transformations in this file take mhlo blocks which end with -+// mhlo::ReturnOp and lower to SCF ops which end with scf::YieldOp. Inline an -+// entire block with the only change being return -> yield. -+void inlineMhloRegionIntoSCFRegion(PatternRewriter& rewriter, Region& mhlo, -+ Region& scf) { -+ // Remove an existing block, then move the region over. -+ if (!scf.empty()) rewriter.eraseBlock(&scf.back()); -+ rewriter.inlineRegionBefore(mhlo, scf, scf.end()); -+ // Fix up the terminator. -+ PatternRewriter::InsertionGuard guard(rewriter); -+ rewriter.setInsertionPointToEnd(&scf.back()); -+ auto* terminator = scf.back().getTerminator(); -+ rewriter.replaceOpWithNewOp(terminator, -+ terminator->getOperands()); -+} -+ -+// mhlo ops need inputs to be tensors, but scalar values can be a scalar tensor -+// or a 1 element tensor. To handle this, collapse shape before extracting the -+// scalar value when necessary. -+Value extractTensorValue(OpBuilder& b, Value tensor) { -+ auto loc = tensor.getLoc(); -+ if (mlir::cast(tensor.getType()).hasRank() && -+ mlir::cast(tensor.getType()).getRank() != 0) { -+ tensor = b.create( -+ loc, tensor, SmallVector()); -+ } -+ return b.create(loc, tensor, ValueRange()); -+} -+ -+struct ScfForBounds { -+ Value lb; -+ Value ub; -+ Value step; -+ unsigned indexArgIndex; -+}; -+ -+std::optional extractForBounds(mhlo::WhileOp op) { -+ auto& cond = op.getCond().front(); -+ auto& body = op.getBody().front(); -+ if (cond.getOperations().size() != 2) return std::nullopt; -+ -+ auto matchBbArg = [](Value v, Block& block) -> std::optional { -+ if (!mlir::isa(v) || v.getParentBlock() != &block) -+ return std::nullopt; -+ return mlir::cast(v).getArgNumber(); -+ }; -+ -+ auto compare = llvm::dyn_cast(cond.front()); -+ // If the rhs of the comapare is defined outside the block, it's a constant -+ // within the loop. -+ if (!compare || -+ compare.getComparisonDirection() != mhlo::ComparisonDirection::LT || -+ compare.getRhs().getParentBlock() == &cond || -+ !getElementTypeOrSelf(compare.getLhs().getType()) -+ .isSignlessIntOrIndex()) { -+ return std::nullopt; -+ } -+ -+ auto iterArg = matchBbArg(compare.getLhs(), cond); -+ if (!iterArg) return std::nullopt; -+ -+ auto add = llvm::dyn_cast_or_null( -+ body.getTerminator()->getOperand(*iterArg).getDefiningOp()); -+ if (!add || matchBbArg(add.getLhs(), body) != iterArg || -+ add.getRhs().getParentBlock() == &body) { -+ return std::nullopt; -+ } -+ -+ ScfForBounds bounds; -+ bounds.ub = compare.getRhs(); -+ bounds.step = add.getRhs(); -+ bounds.lb = op->getOperand(*iterArg); -+ bounds.indexArgIndex = *iterArg; -+ return bounds; -+} -+ -+// Rewrites `mhlo.while` to `scf.while` or `scf.for`. -+struct WhileOpPattern : public OpConversionPattern { -+ using OpConversionPattern::OpConversionPattern; -+ -+ LogicalResult matchAndRewrite( -+ mhlo::WhileOp op, OpAdaptor adaptor, -+ ConversionPatternRewriter& rewriter) const override { -+ auto loc = op.getLoc(); -+ -+ if (auto bounds = extractForBounds(op)) { -+ auto newForOp = rewriter.create( -+ loc, extractTensorValue(rewriter, bounds->lb), -+ extractTensorValue(rewriter, bounds->ub), -+ extractTensorValue(rewriter, bounds->step), adaptor.getOperands()); -+ -+ rewriter.setInsertionPointToEnd(newForOp.getBody()); -+ // Inline while body, and only replace the mhlo.return with an scf.yield. -+ inlineMhloRegionIntoSCFRegion(rewriter, op.getBody(), -+ newForOp.getRegion()); -+ auto indexArg = newForOp.getRegion().insertArgument( -+ unsigned{0}, newForOp.getLowerBound().getType(), loc); -+ auto oldIndexArg = -+ newForOp.getRegion().getArgument(1 + bounds->indexArgIndex); -+ rewriter.setInsertionPointToStart(&newForOp.getRegion().front()); -+ auto indexArgTensor = rewriter.create( -+ loc, oldIndexArg.getType(), indexArg); -+ oldIndexArg.replaceAllUsesWith(indexArgTensor); -+ -+ rewriter.replaceOp(op, newForOp.getResults()); -+ return success(); -+ } -+ -+ auto newWhileOp = rewriter.create(loc, op.getResultTypes(), -+ adaptor.getOperands()); -+ -+ // Inline while condition. The block is the same, except the boolean result -+ // needs to be extracted and used with an scf.condition. -+ rewriter.inlineRegionBefore(op.getCond(), newWhileOp.getBefore(), -+ newWhileOp.getBefore().end()); -+ auto conditionReturn = -+ cast(newWhileOp.getBefore().front().getTerminator()); -+ rewriter.setInsertionPointToEnd(&newWhileOp.getBefore().front()); -+ Value i1 = extractTensorValue(rewriter, conditionReturn->getOperand(0)); -+ rewriter.replaceOpWithNewOp( -+ conditionReturn, i1, newWhileOp.getBeforeArguments()); -+ -+ // Inline while body, and only replace the mhlo.return with an scf.yield. -+ inlineMhloRegionIntoSCFRegion(rewriter, op.getBody(), -+ newWhileOp.getAfter()); -+ -+ rewriter.replaceOp(op, newWhileOp.getResults()); -+ return success(); -+ } -+}; -+ -+// Rewrites `mhlo.if` to `scf.if`. -+struct IfOpPattern : public OpConversionPattern { -+ using OpConversionPattern::OpConversionPattern; -+ -+ LogicalResult matchAndRewrite( -+ mhlo::IfOp op, OpAdaptor adaptor, -+ ConversionPatternRewriter& rewriter) const override { -+ auto scfIf = rewriter.create( -+ op.getLoc(), op.getResultTypes(), -+ extractTensorValue(rewriter, adaptor.getPred()), -+ /*withElseRegion=*/true); -+ inlineMhloRegionIntoSCFRegion(rewriter, op.getTrueBranch(), -+ scfIf.getThenRegion()); -+ inlineMhloRegionIntoSCFRegion(rewriter, op.getFalseBranch(), -+ scfIf.getElseRegion()); -+ rewriter.replaceOp(op, scfIf.getResults()); -+ return success(); -+ } -+}; -+ -+// Rewrites `mhlo.case` to a nested `scf.if`. -+struct CaseOpPattern : public OpConversionPattern { -+ using OpConversionPattern::OpConversionPattern; -+ -+ // Recursively create if/else ops to handle each possible value in a case op. -+ scf::IfOp createNestedCases(int currentIdx, CaseOp op, OpAdaptor adaptor, -+ PatternRewriter& outerBuilder) const { -+ Location loc = op.getLoc(); -+ Value idxValue = adaptor.getIndex(); -+ auto finalIdx = op.getBranches().size() - 2; -+ -+ // Determine if the current index matches the case index. -+ auto scalarType = idxValue.getType(); -+ auto shapedType = mlir::cast(scalarType); -+ auto constAttr = DenseElementsAttr::get( -+ shapedType, {mlir::cast( -+ outerBuilder.getI32IntegerAttr(currentIdx))}); -+ Value currentIdxVal = outerBuilder.create( -+ loc, idxValue.getType(), constAttr); -+ -+ auto scfIf = outerBuilder.create( -+ loc, op.getResultTypes(), -+ extractTensorValue(outerBuilder, outerBuilder.create( -+ loc, idxValue, currentIdxVal, -+ ComparisonDirection::EQ)), -+ /*withElseRegion=*/true); -+ inlineMhloRegionIntoSCFRegion(outerBuilder, op.getBranches()[currentIdx], -+ scfIf.getThenRegion()); -+ int nextIdx = currentIdx + 1; -+ // Don't recurse for the final default block. -+ if (currentIdx == static_cast(finalIdx)) { -+ inlineMhloRegionIntoSCFRegion(outerBuilder, op.getBranches()[nextIdx], -+ scfIf.getElseRegion()); -+ } else { -+ PatternRewriter::InsertionGuard guard(outerBuilder); -+ outerBuilder.setInsertionPointToEnd(&scfIf.getElseRegion().back()); -+ auto innerIf = createNestedCases(nextIdx, op, adaptor, outerBuilder); -+ outerBuilder.create(op.getLoc(), innerIf.getResults()); -+ } -+ return scfIf; -+ } -+ -+ LogicalResult matchAndRewrite( -+ mhlo::CaseOp op, OpAdaptor adaptor, -+ ConversionPatternRewriter& rewriter) const override { -+ // Inline the op if there is only a default block. -+ if (op.getBranches().size() == 1) { -+ Block& block = op.getBranches().front().front(); -+ auto results = block.getTerminator()->getOperands(); -+ // Remove the mhlo.return terminator, then inline the block. -+ rewriter.eraseOp(block.getTerminator()); -+ rewriter.inlineBlockBefore(/*source=*/&block, /*dest=*/op.getOperation(), -+ /*argValues=*/{}); -+ rewriter.replaceOp(op, results); -+ return success(); -+ } -+ -+ // Begin recursion with case 0. -+ rewriter.replaceOp( -+ op, createNestedCases(0, op, adaptor, rewriter).getResults()); -+ return success(); -+ } -+}; -+ -+struct LegalizeControlFlowPass -+ : public impl::LegalizeControlFlowPassBase { -+ // Perform the lowering to MLIR control flow. -+ void runOnOperation() override { -+ func::FuncOp f = getOperation(); -+ MLIRContext* ctx = f.getContext(); -+ -+ RewritePatternSet patterns(&getContext()); -+ patterns.add(&getContext()); -+ -+ mlir::ConversionTarget target(*ctx); -+ target.markUnknownOpDynamicallyLegal([](Operation*) { return true; }); -+ target.addIllegalOp(); -+ -+ if (failed(applyPartialConversion(f, target, std::move(patterns)))) { -+ signalPassFailure(); -+ } -+ } -+}; -+ -+} // namespace -+} // namespace mhlo -+} // namespace mlir -+ -+std::unique_ptr> -+mlir::mhlo::createLegalizeControlFlowPass() { -+ return std::make_unique(); -+} -diff --git a/mhlo/transforms/legalize_sort/legalize_sort.cc b/mhlo/transforms/legalize_sort/legalize_sort.cc -new file mode 100644 -index 00000000..8ba9de9a ---- /dev/null -+++ b/mhlo/transforms/legalize_sort/legalize_sort.cc -@@ -0,0 +1,577 @@ -+/* Copyright 2019 The OpenXLA Authors. -+ -+Licensed under the Apache License, Version 2.0 (the "License"); -+you may not use this file except in compliance with the License. -+You may obtain a copy of the License at -+ -+ http://www.apache.org/licenses/LICENSE-2.0 -+ -+Unless required by applicable law or agreed to in writing, software -+distributed under the License is distributed on an "AS IS" BASIS, -+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+See the License for the specific language governing permissions and -+limitations under the License. -+==============================================================================*/ -+ -+// This file implements logic for lowering mhlo.sort to the SCF dialect. -+#include -+#include -+#include -+ -+#include "llvm/ADT/STLExtras.h" -+#include "mhlo/IR/hlo_ops.h" -+#include "mhlo/transforms/passes.h" -+#include "mlir/Dialect/Arith/IR/Arith.h" -+#include "mlir/Dialect/Arith/Utils/Utils.h" -+#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -+#include "mlir/Dialect/Func/IR/FuncOps.h" -+#include "mlir/Dialect/MemRef/IR/MemRef.h" -+#include "mlir/Dialect/SCF/IR/SCF.h" -+#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project -+#include "mlir/IR/Block.h" -+#include "mlir/IR/Builders.h" -+#include "mlir/IR/BuiltinTypes.h" -+#include "mlir/IR/IRMapping.h" -+#include "mlir/IR/ImplicitLocOpBuilder.h" -+#include "mlir/IR/Location.h" -+#include "mlir/IR/PatternMatch.h" -+#include "mlir/IR/TypeRange.h" -+#include "mlir/IR/ValueRange.h" -+#include "mlir/Pass/Pass.h" -+#include "mlir/Support/LLVM.h" -+#include "mlir/Support/LogicalResult.h" -+#include "mlir/Transforms/DialectConversion.h" -+ -+namespace mlir { -+namespace mhlo { -+ -+#define GEN_PASS_DEF_HLOLEGALIZESORTPASS -+#include "mhlo/transforms/mhlo_passes.h.inc" -+ -+namespace { -+ -+using ::mlir::arith::AddIOp; -+using ::mlir::arith::MinSIOp; -+using ::mlir::arith::SelectOp; -+ -+constexpr int64_t kInsertionSortSize = 16; -+ -+// Inlines the `comparator` region (without terminator) at the current insertion -+// point, replacing the arguments with the given values from `lhs` and `rhs`. -+Value emitComparison(ImplicitLocOpBuilder& b, SmallVector& lhs, -+ SmallVector& rhs, Region& comparator) { -+ assert(comparator.hasOneBlock() && "Comparator must have only one block."); -+ Block& block = comparator.front(); -+ assert(block.getTerminator()->getOperands().size() == 1 && -+ "Comparator must return a single value"); -+ -+ IRMapping mapping; -+ for (auto [idx, arg] : llvm::enumerate(comparator.getArguments())) { -+ Value value = idx % 2 == 0 ? lhs[idx / 2] : rhs[idx / 2]; -+ Type type = RankedTensorType::get({}, value.getType()); -+ mapping.map(arg, b.create(type, value)); -+ } -+ -+ for (Operation& op : block.without_terminator()) b.clone(op, mapping); -+ Value result = mapping.lookup(block.getTerminator()->getOperands().front()); -+ -+ return b.create(result, ValueRange()); -+} -+ -+// Emits a binary search of `pivots` in `arrayMemrefs` (all rank 1) in the range -+// [`left`;`right`). `arrayMemrefs` must be sorted according to `comparator`. -+Value emitBinarySearch(ImplicitLocOpBuilder& b, Value leftInit, Value rightInit, -+ SmallVector& pivots, ValueRange arrayMemrefs, -+ Region& comparator) { -+ SmallVector types{leftInit.getType(), rightInit.getType()}; -+ ArithBuilder arith(b, b.getLoc()); -+ -+ // while ( -+ auto whileOp = -+ b.create(types, SmallVector{leftInit, rightInit}); -+ OpBuilder::InsertionGuard guard(b); -+ -+ // left < right) { -+ Block* before = b.createBlock(&whileOp.getBefore(), {}, types, -+ {whileOp.getLoc(), whileOp.getLoc()}); -+ { -+ Value left = before->getArgument(0), right = before->getArgument(1); -+ b.setInsertionPointToEnd(before); -+ b.create(arith.slt(left, right), before->getArguments()); -+ } -+ -+ Block* after = b.createBlock(&whileOp.getAfter(), {}, types, -+ {whileOp.getLoc(), whileOp.getLoc()}); -+ { -+ Value left = after->getArgument(0), right = after->getArgument(1); -+ b.setInsertionPointToEnd(after); -+ // int mid = (left + right) >> 1; -+ Value one = b.create(1); -+ Value mid = b.create(arith.add(left, right), one); -+ Value midPlusOne = b.create(mid, one); -+ -+ auto arraysAtMid = llvm::to_vector( -+ llvm::map_range(arrayMemrefs, [&](Value arrayMemref) -> Value { -+ return b.create(arrayMemref, mid); -+ })); -+ Value cond = emitComparison(b, pivots, arraysAtMid, comparator); -+ // if (comparator(pivot, array[mid])) -+ // right = mid; -+ // else -+ // left = mid + 1; -+ Value newLeft = arith.select(cond, left, midPlusOne); -+ Value newRight = arith.select(cond, mid, right); -+ -+ // } -+ b.create(ValueRange{newLeft, newRight}); -+ } -+ -+ return whileOp.getResult(0); -+} -+ -+SmallVector loadTensorElements(ImplicitLocOpBuilder& b, -+ ValueRange tensors, Value index) { -+ return llvm::to_vector(llvm::map_range(tensors, [&](Value tensor) -> Value { -+ return b.create(tensor, index); -+ })); -+} -+ -+SmallVector loadMemrefElements(ImplicitLocOpBuilder& b, -+ ValueRange memrefs, Value index) { -+ return llvm::to_vector(llvm::map_range(memrefs, [&](Value memref) -> Value { -+ Type type = mlir::cast(memref.getType()).getElementType(); -+ return b.create(type, memref, index); -+ })); -+} -+ -+void storeMemrefElements(ImplicitLocOpBuilder& b, ValueRange memrefs, -+ Value index, ValueRange values) { -+ for (auto [value, memref] : llvm::zip(values, memrefs)) { -+ b.create(value, memref, index); -+ } -+} -+ -+// Insertion sorts `inputTensors` in the range [`lo`; `hi`), storing the results -+// in `outputMemrefs`. `inputTensors` and `outputMemrefs` must all be rank 1 and -+// of identical size. -+void emitInsertionSort(ImplicitLocOpBuilder& b, Value lo, Value hi, -+ ValueRange inputTensors, ValueRange outputMemrefs, -+ mlir::Region& comparator) { -+ ArithBuilder arith(b, b.getLoc()); -+ Value zero = b.create(0); -+ Value one = b.create(1); -+ -+ // array[lo] = tensors[lo]; -+ storeMemrefElements(b, outputMemrefs, lo, -+ loadTensorElements(b, inputTensors, lo)); -+ -+ // for (int start = lo + 1; start < hi; ++start) -+ { -+ auto forOp = b.create(arith.add(lo, one), hi, one); -+ OpBuilder::InsertionGuard outerGuard(b); -+ b.setInsertionPointToStart(forOp.getBody()); -+ Value start = forOp.getInductionVar(); -+ -+ // T pivot = tensors[start]; -+ auto pivots = loadTensorElements(b, inputTensors, start); -+ -+ // int index = binarySearch(lo, start, pivot, array, comparator); -+ auto index = -+ emitBinarySearch(b, lo, start, pivots, outputMemrefs, comparator); -+ -+ // int n = start - index; // The number of elements to move -+ Value n = arith.sub(start, index); -+ -+ // memmove(&array[index + 1], &array[index], n * sizeof(T)) -+ // memref::CopyOp would be nice to use here, but: -+ // 1. It lowers to a quite inefficient library call in the general case -+ // (strides != 1). -+ // 2. It implements memcpy semantics, but we need memmove here. -+ // So we go with a loop instead. -+ auto copyForOp = b.create(zero, n, one); -+ { -+ OpBuilder::InsertionGuard innerGuard(b); -+ b.setInsertionPointToStart(copyForOp.getBody()); -+ Value copyLoopIndex = copyForOp.getBody()->getArgument(0); -+ -+ Value dstIndex = arith.sub(start, copyLoopIndex); -+ Value srcIndex = arith.sub(dstIndex, one); -+ storeMemrefElements(b, outputMemrefs, dstIndex, -+ loadMemrefElements(b, outputMemrefs, srcIndex)); -+ } -+ // array[index] = pivot; -+ storeMemrefElements(b, outputMemrefs, index, pivots); -+ } -+} -+ -+void emitMerge(ImplicitLocOpBuilder& b, Value lo, Value mid, Value hi, -+ ValueRange readBufs, ValueRange writeBufs, -+ mlir::Region& comparator) { -+ ArithBuilder arith(b, b.getLoc()); -+ // The while loop runs until we reach the end of either interval. It has three -+ // loop-carried variables: -+ // 1. current output index -+ // 2. current read index for interval 1 -+ // 3. current read index for interval 2 -+ SmallVector whileArgTypes{lo.getType(), lo.getType(), mid.getType()}; -+ SmallVector whileInitArgs{lo, lo, mid}; -+ SmallVector whileArgLocs(whileArgTypes.size(), b.getLoc()); -+ -+ // while( -+ auto whileOp = b.create(whileArgTypes, whileInitArgs); -+ { -+ OpBuilder::InsertionGuard guard(b); -+ { -+ Block* before = -+ b.createBlock(&whileOp.getBefore(), {}, whileArgTypes, whileArgLocs); -+ Value i0 = before->getArgument(1), i1 = before->getArgument(2); -+ b.setInsertionPointToEnd(before); -+ -+ // i0 < mid && i1 < hi) { -+ Value inbounds0 = arith.slt(i0, mid); -+ Value inbounds1 = arith.slt(i1, hi); -+ -+ b.create(arith._and(inbounds0, inbounds1), -+ before->getArguments()); -+ } -+ -+ { -+ Block* after = -+ b.createBlock(&whileOp.getAfter(), {}, whileArgTypes, whileArgLocs); -+ Value iOut = after->getArgument(0), i0 = after->getArgument(1), -+ i1 = after->getArgument(2); -+ b.setInsertionPointToEnd(after); -+ -+ // auto vals0 = readBufs[i0], vals1 = readBufs[i1]; -+ SmallVector vals0 = loadMemrefElements(b, readBufs, i0); -+ SmallVector vals1 = loadMemrefElements(b, readBufs, i1); -+ -+ // writeBufs[iOut] = comparator(vals1, vals0) -+ // ? readBufs[i1++] : readBufs[i0++]; -+ Value cmp = emitComparison(b, vals1, vals0, comparator); -+ SmallVector pickedVals; -+ for (auto [val0, val1] : llvm::zip(vals0, vals1)) { -+ pickedVals.push_back(b.create(cmp, val1, val0)); -+ } -+ storeMemrefElements(b, writeBufs, iOut, pickedVals); -+ -+ Value one = b.create(1); -+ Value nexti0 = b.create(cmp, i0, arith.add(i0, one)); -+ Value nexti1 = b.create(cmp, arith.add(i1, one), i1); -+ // ++iOut; -+ Value nextIOut = b.create(iOut, one); -+ b.create(ValueRange{nextIOut, nexti0, nexti1}); -+ } -+ } -+ -+ // At this point, exactly one of the input ranges will have leftover elements. -+ Value iOut = whileOp->getResult(0); -+ Value i0 = whileOp->getResult(1); -+ Value i1 = whileOp->getResult(2); -+ -+ // We could use memref::CopyOp here, but typically, there aren't many leftover -+ // elements for randomly shuffled inputs. -+ Value leftoverIn0 = arith.slt(i0, mid); -+ Value start = arith.select(leftoverIn0, i0, i1); -+ Value end = arith.select(leftoverIn0, mid, hi); -+ Value n = arith.sub(end, start); -+ -+ Value zero = b.create(0); -+ Value one = b.create(1); -+ auto forOp = b.create(zero, n, one); -+ b.setInsertionPointToStart(forOp.getBody()); -+ Value copyIndex = forOp.getBody()->getArgument(0); -+ -+ Value srcIndex = arith.add(start, copyIndex); -+ Value dstIndex = arith.add(iOut, copyIndex); -+ storeMemrefElements(b, writeBufs, dstIndex, -+ loadMemrefElements(b, readBufs, srcIndex)); -+} -+ -+// Emits a bottom up merge sort of `inputTensors` in the range [`lo`; `hi`), and -+// writes the results to either `outputs0` or `outputs1`. -+// Returns 0 if the results are in `outputs0`, 1 if they are in `outputs1`. -+// TODO(jreiffers): Consider implementing top-down merge sort. -+Value emitBottomUpMergeSort(ImplicitLocOpBuilder& b, Value lo, Value hi, -+ int64_t staticSortDimSize, ValueRange inputTensors, -+ ValueRange outputs0, ValueRange outputs1, -+ mlir::Region& comparator) { -+ ArithBuilder arith(b, b.getLoc()); -+ Value size = arith.sub(hi, lo); -+ -+ Value zero = b.create(0); -+ Value insertionSortSize = -+ b.create(kInsertionSortSize); -+ -+ // Run insertion sort on blocks of size kInsertionSortSize. -+ // for (int start = 0; start < size; start += kInsertionSortSize) { -+ { -+ auto forOp = b.create(zero, size, insertionSortSize); -+ OpBuilder::InsertionGuard guard(b); -+ b.setInsertionPointToStart(forOp.getBody()); -+ Value start = forOp.getBody()->getArgument(0); -+ Value end = arith.add( -+ b.create(arith.add(start, insertionSortSize), size), lo); -+ emitInsertionSort(b, start, end, inputTensors, outputs0, comparator); -+ } -+ -+ Value initParity = b.create(0, 1); -+ if (staticSortDimSize >= 0 && staticSortDimSize < kInsertionSortSize) { -+ return initParity; -+ } -+ -+ // The while arguments are: -+ // 1. the current size -+ // 2. the original index of the buffers we're currently reading from -+ // 3. the buffers we're currently reading from -+ // 4. the buffers we're currently writing to. -+ // -+ // 1 gets doubled each iteration, 2 gets negated, 3 and 4 are swapped. -+ // int currentSize = 16; -+ SmallVector whileInitArgs{insertionSortSize, initParity}; -+ // First we read from `outputs0` (initialized by the insertion sort above). -+ llvm::copy(outputs0, std::back_inserter(whileInitArgs)); -+ llvm::copy(outputs1, std::back_inserter(whileInitArgs)); -+ -+ SmallVector whileArgTypes; -+ for (auto val : whileInitArgs) whileArgTypes.push_back(val.getType()); -+ -+ SmallVector whileArgLocs(whileArgTypes.size(), b.getLoc()); -+ -+ // while ( -+ auto whileOp = b.create(whileArgTypes, whileInitArgs); -+ OpBuilder::InsertionGuard guard(b); -+ -+ // currentSize < totalSize) -+ { -+ Block* before = -+ b.createBlock(&whileOp.getBefore(), {}, whileArgTypes, whileArgLocs); -+ Value currentSize = before->getArgument(0); -+ b.setInsertionPointToEnd(before); -+ b.create(arith.slt(currentSize, size), -+ before->getArguments()); -+ } -+ -+ size_t numArgs = inputTensors.size(); -+ // { -+ { -+ Block* after = -+ b.createBlock(&whileOp.getAfter(), {}, whileArgTypes, whileArgLocs); -+ -+ Value currentSize = after->getArgument(0); -+ Value parity = after->getArgument(1); -+ auto readBufs = after->getArguments().drop_front(2).take_front(numArgs); -+ auto writeBufs = after->getArguments().take_back(numArgs); -+ -+ Value twoCurrentSize = arith.add(currentSize, currentSize); -+ -+ // for (int start = 0; start < size; start += 2*currentSize) { -+ { -+ auto forOp = b.create(zero, size, twoCurrentSize); -+ b.setInsertionPointToStart(forOp.getBody()); -+ Value start = forOp.getBody()->getArgument(0); -+ -+ Value mid = b.create(size, arith.add(start, currentSize)); -+ Value end = b.create(size, arith.add(start, twoCurrentSize)); -+ emitMerge(b, start, mid, end, readBufs, writeBufs, comparator); -+ b.setInsertionPointAfter(forOp); -+ } -+ // } -+ -+ // parity = !parity; -+ Value one = b.create(1, 1); -+ Value notParity = arith.sub(one, parity); -+ // currentSize *= 2; -+ SmallVector nextWhileArgs{twoCurrentSize, notParity}; -+ llvm::copy(writeBufs, std::back_inserter(nextWhileArgs)); -+ llvm::copy(readBufs, std::back_inserter(nextWhileArgs)); -+ b.create(nextWhileArgs); -+ } -+ // } -+ -+ // The result is the parity bit. -+ return whileOp.getResults().drop_front(1).front(); -+} -+ -+// Helper struct for extracting 1d slices from tensors and memrefs. -+struct Slicer { -+ Slicer(OpBuilder& b, uint64_t sortDim, Value sortDimSize, ValueRange ivs) -+ : sizes(ivs.size() + 1, b.getI64IntegerAttr(1)), -+ strides(ivs.size() + 1, b.getI64IntegerAttr(1)) { -+ sizes[sortDim] = sortDimSize; -+ for (size_t i = 0; i < ivs.size() + 1; ++i) { -+ if (i == sortDim) { -+ offsets.push_back(b.getI64IntegerAttr(0)); -+ } else { -+ offsets.push_back(ivs[i - static_cast(i > sortDim)]); -+ } -+ } -+ } -+ -+ RankedTensorType toSlicedType(RankedTensorType sourceType) { -+ return tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( -+ /*resultRank=*/1, sourceType, offsets, sizes, strides); -+ } -+ -+ MemRefType toSlicedType(MemRefType sourceType) { -+ return mlir::cast(memref::SubViewOp::inferRankReducedResultType( -+ {ShapedType::kDynamic} /*1D output*/, sourceType, offsets, sizes, -+ strides)); -+ } -+ -+ template -+ Value slice(ImplicitLocOpBuilder& b, Value input) { -+ Ty ty = mlir::cast(input.getType()); -+ return b.create(toSlicedType(ty), input, offsets, sizes, strides) -+ .getResult(); -+ } -+ -+ Value apply(ImplicitLocOpBuilder& b, Value input) { -+ Type inTy = input.getType(); -+ if (mlir::isa(inTy)) { -+ return slice(b, input); -+ } -+ assert(mlir::isa(inTy)); -+ return slice(b, input); -+ } -+ -+ SmallVector offsets; -+ SmallVector sizes; -+ SmallVector strides; -+}; -+ -+SmallVector sliceMemrefsOrTensors(ImplicitLocOpBuilder& b, -+ SmallVector& ivs, -+ Value sortDimSize, -+ ValueRange memrefsOrTensors, -+ SortOp op) { -+ if (ivs.empty()) return memrefsOrTensors; -+ -+ SmallVector outputs; -+ Slicer slicer(b, op.getDimension(), sortDimSize, ivs); -+ // Create subviews/slices. -+ for (Value out : memrefsOrTensors) { -+ outputs.push_back(slicer.apply(b, out)); -+ } -+ -+ return outputs; -+} -+ -+struct SortOpPattern : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ -+ LogicalResult matchAndRewrite(SortOp op, -+ PatternRewriter& rewriter) const override { -+ ImplicitLocOpBuilder b(op.getLoc(), rewriter); -+ -+ // Note: the output memrefs aren't necessarily the ones that we return, -+ SmallVector outputMemrefs; -+ SmallVector scratchMemrefs; -+ -+ Value firstOperand = op.getOperands().front(); -+ auto firstOperandType = mlir::cast(firstOperand.getType()); -+ int64_t inputRank = firstOperandType.getRank(); -+ -+ Value sortDimSize = b.createOrFold( -+ firstOperand, b.create(op.getDimension())); -+ int64_t staticSortDimSize = firstOperandType.getDimSize(op.getDimension()); -+ -+ SmallVector dynamicDims; -+ for (int i = 0; i < inputRank; ++i) { -+ if (!firstOperandType.isDynamicDim(i)) continue; -+ Value index = b.create(i); -+ Value dimOp = b.create(firstOperand, index); -+ dynamicDims.push_back(dimOp); -+ } -+ -+ // Allocate output and scratch memrefs. If the size of the sort dimension is -+ // statically known to be <= kInsertionSortSize, `scratchMemrefs` are unused -+ // and will be cleaned up later. -+ for (auto input : op.getOperands()) { -+ auto inputType = mlir::cast(input.getType()); -+ auto memRefType = -+ MemRefType::get(inputType.getShape(), inputType.getElementType()); -+ -+ outputMemrefs.push_back( -+ b.create(memRefType, dynamicDims)); -+ scratchMemrefs.push_back( -+ b.create(memRefType, dynamicDims)); -+ } -+ -+ b.setInsertionPoint(op); -+ Value zero = b.create(0); -+ Value one = b.create(1); -+ -+ Value forInitArg = b.create(0, 1); -+ SmallVector forOps; -+ SmallVector ivs; -+ forOps.reserve(inputRank - 1); -+ ivs.reserve(inputRank - 1); -+ for (int64_t i = 0; i < inputRank; ++i) { -+ if (i != static_cast(op.getDimension())) { -+ Value dim = b.create(i); -+ Value ub = b.create(firstOperand, dim); -+ scf::ForOp& forOp = forOps.emplace_back( -+ b.create(zero, ub, one, ValueRange{forInitArg})); -+ ivs.push_back(forOp.getInductionVar()); -+ b.setInsertionPointToStart(&forOp.getRegion().front()); -+ } -+ } -+ SmallVector inputs = -+ sliceMemrefsOrTensors(b, ivs, sortDimSize, op.getOperands(), op); -+ SmallVector outputs = -+ sliceMemrefsOrTensors(b, ivs, sortDimSize, outputMemrefs, op); -+ SmallVector scratches = -+ sliceMemrefsOrTensors(b, ivs, sortDimSize, scratchMemrefs, op); -+ -+ Value parity = -+ emitBottomUpMergeSort(b, zero, sortDimSize, staticSortDimSize, inputs, -+ outputs, scratches, op.getRegion()); -+ -+ // Pass the parity bit through the for loops. -+ for (auto i = static_cast(forOps.size() - 1); i >= 0; --i) { -+ b.setInsertionPointToEnd(&forOps[i].getRegion().front()); -+ b.create(ValueRange{parity}); -+ parity = forOps[i]->getResult(0); -+ } -+ b.setInsertionPoint(op); -+ -+ SmallVector outputTensors; -+ for (auto [out0, out1] : llvm::zip(outputMemrefs, scratchMemrefs)) { -+ outputTensors.push_back(b.create( -+ b.create(parity, out1, out0), /*restrict=*/true)); -+ } -+ -+ rewriter.replaceOp(op, outputTensors); -+ return success(); -+ } -+}; -+ -+struct LegalizeSortPass -+ : public impl::HloLegalizeSortPassBase { -+ // Perform the lowering to MLIR control flow. -+ void runOnOperation() override { -+ func::FuncOp f = getOperation(); -+ MLIRContext* ctx = f.getContext(); -+ -+ RewritePatternSet patterns(ctx); -+ patterns.add(ctx); -+ -+ mlir::ConversionTarget target(*ctx); -+ target.markUnknownOpDynamicallyLegal([](Operation*) { return true; }); -+ target.addIllegalOp(); -+ -+ if (failed(applyPartialConversion(f, target, std::move(patterns)))) { -+ signalPassFailure(); -+ } -+ } -+}; -+ -+} // namespace -+} // namespace mhlo -+} // namespace mlir -+ -+std::unique_ptr> -+mlir::mhlo::createLegalizeSortPass() { -+ return std::make_unique(); -+} -diff --git a/mhlo/transforms/legalize_to_standard/legalize_to_standard.cc b/mhlo/transforms/legalize_to_standard/legalize_to_standard.cc -new file mode 100644 -index 00000000..be752397 ---- /dev/null -+++ b/mhlo/transforms/legalize_to_standard/legalize_to_standard.cc -@@ -0,0 +1,243 @@ -+/* Copyright 2019 The OpenXLA Authors. -+ -+Licensed under the Apache License, Version 2.0 (the "License"); -+you may not use this file except in compliance with the License. -+You may obtain a copy of the License at -+ -+ http://www.apache.org/licenses/LICENSE-2.0 -+ -+Unless required by applicable law or agreed to in writing, software -+distributed under the License is distributed on an "AS IS" BASIS, -+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+See the License for the specific language governing permissions and -+limitations under the License. -+==============================================================================*/ -+ -+// This file implements logic for lowering MHLO dialect to Standard dialect. -+ -+#include -+#include -+#include -+ -+#include "mhlo/IR/hlo_ops.h" -+#include "mhlo/transforms/passes.h" -+#include "mhlo/transforms/rewriters.h" -+#include "mlir/Dialect/Arith/IR/Arith.h" -+#include "mlir/Dialect/Func/IR/FuncOps.h" -+#include "mlir/Dialect/Math/IR/Math.h" -+#include "mlir/IR/BuiltinOps.h" -+#include "mlir/Pass/Pass.h" -+#include "mlir/Support/LLVM.h" -+#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -+ -+namespace mlir { -+namespace { -+#include "legalize_to_standard/generated_legalize_to_standard.inc" -+} // end anonymous namespace -+namespace mhlo { -+ -+#define GEN_PASS_DEF_LEGALIZETOSTANDARDPASS -+#include "mhlo/transforms/mhlo_passes.h.inc" -+ -+namespace { -+ -+class CompareIConvert : public OpRewritePattern { -+ public: -+ using OpRewritePattern::OpRewritePattern; -+ -+ LogicalResult matchAndRewrite(mhlo::CompareOp op, -+ PatternRewriter &rewriter) const override { -+ auto lhs = op.getLhs(); -+ auto rhs = op.getRhs(); -+ auto lhsType = mlir::cast(lhs.getType()); -+ auto rhsType = mlir::cast(rhs.getType()); -+ -+ // Broadcasting not supported by this rewrite. -+ if (lhsType.getShape() != rhsType.getShape()) return failure(); -+ -+ if (!lhsType.getElementType().isSignlessInteger() || -+ !rhsType.getElementType().isSignlessInteger()) -+ return failure(); -+ -+ std::optional comparePredicate = std::nullopt; -+ switch (op.getComparisonDirection()) { -+ case ComparisonDirection::EQ: -+ comparePredicate = arith::CmpIPredicate::eq; -+ break; -+ case ComparisonDirection::NE: -+ comparePredicate = arith::CmpIPredicate::ne; -+ break; -+ case ComparisonDirection::LT: -+ comparePredicate = arith::CmpIPredicate::slt; -+ break; -+ case ComparisonDirection::LE: -+ comparePredicate = arith::CmpIPredicate::sle; -+ break; -+ case ComparisonDirection::GT: -+ comparePredicate = arith::CmpIPredicate::sgt; -+ break; -+ case ComparisonDirection::GE: -+ comparePredicate = arith::CmpIPredicate::sge; -+ break; -+ } -+ -+ if (!comparePredicate.has_value()) return failure(); -+ -+ rewriter.replaceOpWithNewOp(op, comparePredicate.value(), -+ lhs, rhs); -+ return success(); -+ } -+}; -+ -+class CompareFConvert : public OpRewritePattern { -+ public: -+ using OpRewritePattern::OpRewritePattern; -+ -+ LogicalResult matchAndRewrite(mhlo::CompareOp op, -+ PatternRewriter &rewriter) const override { -+ auto lhs = op.getLhs(); -+ auto rhs = op.getRhs(); -+ auto lhsType = mlir::cast(lhs.getType()); -+ auto rhsType = mlir::cast(rhs.getType()); -+ -+ // Broadcasting not supported by this rewrite. -+ if (lhsType.getShape() != rhsType.getShape()) return failure(); -+ -+ if (!mlir::isa(lhsType.getElementType()) || -+ !mlir::isa(rhsType.getElementType())) -+ return failure(); -+ -+ std::optional comparePredicate = std::nullopt; -+ switch (op.getComparisonDirection()) { -+ case ComparisonDirection::EQ: -+ comparePredicate = arith::CmpFPredicate::OEQ; -+ break; -+ case ComparisonDirection::NE: -+ comparePredicate = arith::CmpFPredicate::UNE; -+ break; -+ case ComparisonDirection::LT: -+ comparePredicate = arith::CmpFPredicate::OLT; -+ break; -+ case ComparisonDirection::LE: -+ comparePredicate = arith::CmpFPredicate::OLE; -+ break; -+ case ComparisonDirection::GT: -+ comparePredicate = arith::CmpFPredicate::OGT; -+ break; -+ case ComparisonDirection::GE: -+ comparePredicate = arith::CmpFPredicate::OGE; -+ break; -+ } -+ -+ if (!comparePredicate.has_value()) return failure(); -+ -+ rewriter.replaceOpWithNewOp(op, comparePredicate.value(), -+ lhs, rhs); -+ return success(); -+ } -+}; -+ -+// Replace IotaOp with an integer constant. A ConvertOp is added to -+// convert the integer constant to iota result type. For complex types, the real -+// part is replaced with the generated constant and the imaginary part is -+// replaced with zero tensor. -+class ConvertIotaOp : public OpRewritePattern { -+ public: -+ using OpRewritePattern::OpRewritePattern; -+ -+ LogicalResult matchAndRewrite(mhlo::IotaOp op, -+ PatternRewriter &rewriter) const override { -+ auto outputType = mlir::cast(op.getType()); -+ auto outputSize = outputType.getNumElements(); -+ auto dimension = op.getIotaDimension(); -+ auto maxDimSize = outputType.getDimSize(dimension); -+ -+ auto elementType = outputType.getElementType(); -+ int bitwidth; -+ -+ auto complexTy = mlir::dyn_cast(elementType); -+ Type intOrFloatTy = elementType; -+ if (complexTy) intOrFloatTy = complexTy.getElementType(); -+ -+ bitwidth = intOrFloatTy.getIntOrFloatBitWidth(); -+ llvm::SmallVector values; -+ values.reserve(outputSize); -+ -+ int64_t increaseStride = outputSize; -+ for (uint64_t i = 0; i <= dimension; i++) { -+ increaseStride /= outputType.getDimSize(i); -+ } -+ -+ int64_t currentValue = 0; -+ for (int i = 0; i < outputSize; i++) { -+ int64_t value = (currentValue / increaseStride) % maxDimSize; -+ values.push_back(APInt(bitwidth, value)); -+ ++currentValue; -+ } -+ -+ auto intShapeType = RankedTensorType::get( -+ outputType.getShape(), -+ IntegerType::get(rewriter.getContext(), bitwidth)); -+ auto loc = op.getLoc(); -+ auto integerConst = rewriter.create( -+ loc, DenseIntElementsAttr::get(intShapeType, values)); -+ -+ auto intOrFloatShapeTy = -+ RankedTensorType::get(outputType.getShape(), intOrFloatTy); -+ -+ auto iotaConst = -+ rewriter.create(loc, intOrFloatShapeTy, integerConst); -+ -+ // For int/float types we are done, replace op and return. -+ if (!complexTy) { -+ rewriter.replaceOp(op, iotaConst.getResult()); -+ return success(); -+ } -+ -+ // For complex types, generate a constant tensor of zeroes for the imaginary -+ // part and use iota_const for real part. -+ auto zeroes = rewriter.create( -+ loc, DenseIntElementsAttr::get(intShapeType, APInt(bitwidth, 0))); -+ auto imagZeroes = -+ rewriter.create(loc, intOrFloatShapeTy, zeroes); -+ rewriter.replaceOpWithNewOp(op, iotaConst, imagZeroes); -+ return success(); -+ } -+}; -+ -+} // end anonymous namespace -+ -+namespace { -+struct LegalizeToStandardPass -+ : public impl::LegalizeToStandardPassBase { -+ void getDependentDialects(DialectRegistry ®istry) const override { -+ registry -+ .insert(); -+ } -+ -+ /// Perform the lowering to Standard dialect. -+ void runOnOperation() override; -+}; -+} // end anonymous namespace -+ -+std::unique_ptr> -+createLegalizeToStdPass() { -+ return std::make_unique(); -+} -+ -+void populateMhloToStdPatterns(RewritePatternSet *patterns, -+ mlir::MLIRContext *ctx) { -+ mlir::populateWithGenerated(*patterns); -+ patterns->add(ctx); -+} -+ -+/// Perform the lowering to standard dialect. -+void LegalizeToStandardPass::runOnOperation() { -+ RewritePatternSet patterns(&getContext()); -+ mlir::mhlo::populateMhloToStdPatterns(&patterns, &getContext()); -+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) -+ return signalPassFailure(); -+} -+ -+} // end namespace mhlo -+} // end namespace mlir -diff --git a/mhlo/transforms/legalize_to_standard/legalize_to_standard_patterns.td b/mhlo/transforms/legalize_to_standard/legalize_to_standard_patterns.td -new file mode 100644 -index 00000000..f4d24608 ---- /dev/null -+++ b/mhlo/transforms/legalize_to_standard/legalize_to_standard_patterns.td -@@ -0,0 +1,92 @@ -+/* Copyright 2019 The OpenXLA Authors. -+ -+Licensed under the Apache License, Version 2.0 (the "License"); -+you may not use this file except in compliance with the License. -+You may obtain a copy of the License at -+ -+ http://www.apache.org/licenses/LICENSE-2.0 -+ -+Unless required by applicable law or agreed to in writing, software -+distributed under the License is distributed on an "AS IS" BASIS, -+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+See the License for the specific language governing permissions and -+limitations under the License. -+==============================================================================*/ -+ -+// This is the legalization pattern definition file for MHLO to StandardOps. -+ -+include "mlir/IR/OpBase.td" -+include "mlir/Dialect/Arith/IR/ArithOps.td" -+include "mlir/Dialect/Math/IR/MathOps.td" -+include "mlir/Dialect/Func/IR/FuncOps.td" -+include "mhlo/IR/hlo_ops.td" -+ -+//===----------------------------------------------------------------------===// -+// Nullary op patterns. -+//===----------------------------------------------------------------------===// -+ -+def : Pat<(MHLO_ConstantOp ElementsAttr:$value), -+ (Arith_ConstantOp $value)>; -+ -+//===----------------------------------------------------------------------===// -+// Binary op patterns. -+//===----------------------------------------------------------------------===// -+ -+def IsSameSizePred : CPred< -+ "cast($0.getType()).getShape() " -+ "== cast($1.getType()).getShape()">; -+def IsSameSizeConstraint : Constraint; -+def createFastMathNone : NativeCodeCall< -+ "::mlir::arith::FastMathFlagsAttr::get(" -+ "$_builder.getContext(), ::mlir::arith::FastMathFlags::none" -+ ")">; -+def createOverflowNone : NativeCodeCall< -+ "::mlir::arith::IntegerOverflowFlagsAttr::get(" -+ "$_builder.getContext(), ::mlir::arith::IntegerOverflowFlags::none" -+ ")">; -+ -+ -+// Unary Lowering Patterns. -+def : Pat<(MHLO_CeilOp MHLO_FpTensor:$i), (Math_CeilOp $i, (createFastMathNone ))>; -+ -+// Binary Lowering Patterns. -+def : Pat<(MHLO_AndOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), -+ (Arith_AndIOp $l, $r), -+ [(IsSameSizeConstraint $l, $r)]>; -+def : Pat<(MHLO_OrOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), -+ (Arith_OrIOp $l, $r), -+ [(IsSameSizeConstraint $l, $r)]>; -+def : Pat<(MHLO_AddOp MHLO_FpTensor:$l, MHLO_FpTensor:$r), -+ (Arith_AddFOp $l, $r, (createFastMathNone )), -+ [(IsSameSizeConstraint $l, $r)]>; -+def : Pat<(MHLO_SubtractOp MHLO_FpTensor:$l, MHLO_FpTensor:$r), -+ (Arith_SubFOp $l, $r, (createFastMathNone )), -+ [(IsSameSizeConstraint $l, $r)]>; -+def : Pat<(MHLO_MulOp MHLO_FpTensor:$l, MHLO_FpTensor:$r), -+ (Arith_MulFOp $l, $r, (createFastMathNone )), -+ [(IsSameSizeConstraint $l, $r)]>; -+def : Pat<(MHLO_DivOp MHLO_FpTensor:$l, MHLO_FpTensor:$r), -+ (Arith_DivFOp $l, $r, (createFastMathNone )), -+ [(IsSameSizeConstraint $l, $r)]>; -+def : Pat<(MHLO_RemOp MHLO_FpTensor:$l, MHLO_FpTensor:$r), -+ (Arith_RemFOp $l, $r, (createFastMathNone )), -+ [(IsSameSizeConstraint $l, $r)]>; -+def : Pat<(MHLO_AddOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), -+ (Arith_AddIOp $l, $r, (createOverflowNone )), -+ [(IsSameSizeConstraint $l, $r)]>; -+def : Pat<(MHLO_SubtractOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), -+ (Arith_SubIOp $l, $r, (createOverflowNone )), -+ [(IsSameSizeConstraint $l, $r)]>; -+def : Pat<(MHLO_MulOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), -+ (Arith_MulIOp $l, $r, (createOverflowNone )), -+ [(IsSameSizeConstraint $l, $r)]>; -+def : Pat<(MHLO_DivOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), -+ (Arith_DivSIOp $l, $r), -+ [(IsSameSizeConstraint $l, $r)]>; -+def : Pat<(MHLO_RemOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), -+ (Arith_RemSIOp $l, $r), -+ [(IsSameSizeConstraint $l, $r)]>; -+ -+def : Pat<(MHLO_SelectOp $pred, $tv, $fv), -+ (SelectOp $pred, $tv, $fv), -+ [(IsSameSizeConstraint $pred, $tv), (IsSameSizeConstraint $tv, $fv)]>; -diff --git a/mhlo/transforms/mhlo_passes.td b/mhlo/transforms/mhlo_passes.td -index 853531c1..378f8944 100644 ---- a/mhlo/transforms/mhlo_passes.td -+++ b/mhlo/transforms/mhlo_passes.td -@@ -15,6 +15,25 @@ limitations under the License. - - include "mlir/Pass/PassBase.td" - -+def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "func::FuncOp"> { -+ let summary = "Legalize from MHLO control flow to SCF control flow."; -+ let constructor = "createLegalizeControlFlowPass()"; -+ let dependentDialects = ["scf::SCFDialect", "tensor::TensorDialect"]; -+} -+ -+def LegalizeToStandardPass : Pass<"mhlo-legalize-to-std", "func::FuncOp"> { -+ let summary = "Legalize from MHLO dialect to standard dialect."; -+ let constructor = "createLegalizeToStdPass()"; -+} -+ -+def HloLegalizeSortPass : Pass<"hlo-legalize-sort", "func::FuncOp"> { -+ let summary = "Legalize from MHLO sort to SCF control flow."; -+ let constructor = "createLegalizeSortPass()"; -+ let dependentDialects = ["arith::ArithDialect", -+ "bufferization::BufferizationDialect", -+ "scf::SCFDialect", "tensor::TensorDialect"]; -+} -+ - def ChloLegalizeToHighLevelMhloPass : Pass<"chlo-legalize-to-high-level-mhlo", "func::FuncOp"> { - let summary = "Legalize CHLO's with XLA counterparts, like TopK and Erf."; - let description = [{ -diff --git a/mhlo/transforms/passes.h b/mhlo/transforms/passes.h -index 3d2aa3b3..3f03b2df 100644 ---- a/mhlo/transforms/passes.h -+++ b/mhlo/transforms/passes.h -@@ -37,6 +37,10 @@ namespace mhlo { - #define GEN_PASS_DECL - #include "mhlo/transforms/mhlo_passes.h.inc" - -+std::unique_ptr> createLegalizeControlFlowPass(); -+std::unique_ptr> createLegalizeSortPass(); -+std::unique_ptr> createLegalizeToStdPass(); -+ - /// Lowers from HLO dialect to Arithmetic dialect. - std::unique_ptr> createLegalizeToArithmeticPass(); - --- -2.34.1 - diff --git a/mlir/patches/mhlo-remove-shardy.patch b/mlir/patches/mhlo-remove-shardy.patch deleted file mode 100644 index f78200bdab..0000000000 --- a/mlir/patches/mhlo-remove-shardy.patch +++ /dev/null @@ -1,132 +0,0 @@ -From 70172e8399383d6c1964d73a2d20cba3c55a3279 Mon Sep 17 00:00:00 2001 -From: paul0403 -Date: Thu, 29 May 2025 10:06:35 -0400 -Subject: [PATCH] remove shardy dependency - ---- - bindings/c/CMakeLists.txt | 1 - - stablehlo_ext/CMakeLists.txt | 1 + - stablehlo_ext/analysis/CMakeLists.txt | 3 ++- - stablehlo_ext/transforms/CMakeLists.txt | 7 ++++++- - stablehlo_ext/transforms/stablehlo_refine_shapes.cpp | 3 --- - tests/lit.cfg.py | 1 + - tools/mlir-hlo-opt/mlir-hlo-opt.cc | 2 -- - 7 files changed, 10 insertions(+), 8 deletions(-) - -diff --git a/bindings/c/CMakeLists.txt b/bindings/c/CMakeLists.txt -index fd2a5c2c..53d916d5 100644 ---- a/bindings/c/CMakeLists.txt -+++ b/bindings/c/CMakeLists.txt -@@ -10,7 +10,6 @@ add_mlir_public_c_api_library(MLIRHLOCAPIDialects - MhloPasses - MhloToArithmeticConversion - MhloToMemrefConversion -- MhloToStandard - MhloToLinalg - MhloToStablehlo - StablehloToMhlo -diff --git a/stablehlo_ext/CMakeLists.txt b/stablehlo_ext/CMakeLists.txt -index 3e55a89d..e8d318f1 100644 ---- a/stablehlo_ext/CMakeLists.txt -+++ b/stablehlo_ext/CMakeLists.txt -@@ -12,5 +12,6 @@ - # See the License for the specific language governing permissions and - # limitations under the License. - -+add_subdirectory(analysis) - add_subdirectory(IR) - add_subdirectory(transforms) -diff --git a/stablehlo_ext/analysis/CMakeLists.txt b/stablehlo_ext/analysis/CMakeLists.txt -index 726d340d..0c0259b8 100644 ---- a/stablehlo_ext/analysis/CMakeLists.txt -+++ b/stablehlo_ext/analysis/CMakeLists.txt -@@ -1,5 +1,6 @@ - add_mlir_library(MhloAnalysis -- shape_component_analysis.cc -+ shape_component_analysis.cpp -+ PARTIAL_SOURCES_INTENDED - - DEPENDS - mlir-headers -diff --git a/stablehlo_ext/transforms/CMakeLists.txt b/stablehlo_ext/transforms/CMakeLists.txt -index ee58f490..2d7cc22c 100644 ---- a/stablehlo_ext/transforms/CMakeLists.txt -+++ b/stablehlo_ext/transforms/CMakeLists.txt -@@ -20,9 +20,14 @@ add_mlir_dialect_library(StablehloExtensionPasses - PARTIAL_SOURCES_INTENDED - chlo_recompose_ops.cpp - chlo_preserve_high_level_ops.cpp -+ sink_constants_to_control_flow.cpp -+ stablehlo_add_quant_dequant_conv.cpp - stablehlo_canonicalize_dynamism.cpp -+ stablehlo_canonicalize_from_hlo_import.cpp -+ stablehlo_legalize_quant_composite.cpp -+ stablehlo_prepare_for_hlo_export.cpp - stablehlo_refine_shapes.cpp -- sdy_refine_shapes.cpp -+ symbolic_shape_optimization.cpp - - DEPENDS - StablehloExtensionPassesIncGen -diff --git a/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp b/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp -index cabd6a9f..2e64b4ed 100644 ---- a/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp -+++ b/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp -@@ -34,7 +34,6 @@ limitations under the License. - #include "stablehlo_ext/IR/base.h" - #include "stablehlo_ext/IR/stablehlo_ops.h" - #include "stablehlo_ext/transforms/passes.h" // NOLINT: Used in passes.h.inc --#include "stablehlo_ext/transforms/sdy_refine_shapes.h" - - namespace mlir { - namespace stablehlo_ext { -@@ -154,7 +153,6 @@ struct StablehloRefineShapesPass - patterns->add(context); - patterns->add(context); - patterns->add(context); -- populateSdyShapeRefinementPatterns(patterns, context); - }; - - if (failed(stablehlo::refineEntryFunction(*context, func, -@@ -172,7 +170,6 @@ void populateStablehloExtRefineShapesPatterns(RewritePatternSet *patterns, - patterns->add(context); - patterns->add(context); - patterns->add(context); -- populateSdyShapeRefinementPatterns(patterns, context); - } - - } // namespace stablehlo_ext -diff --git a/tests/lit.cfg.py b/tests/lit.cfg.py -index ab20fbb5..6c61aec5 100644 ---- a/tests/lit.cfg.py -+++ b/tests/lit.cfg.py -@@ -32,6 +32,7 @@ config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) - - # suffixes: A list of file extensions to treat as test files. - config.suffixes = ['.mlir'] -+config.excludes = ['sdy_refine_shapes.mlir'] - - # test_source_root: The root path where tests are located. - config.test_source_root = os.path.dirname(__file__) -diff --git a/tools/mlir-hlo-opt/mlir-hlo-opt.cc b/tools/mlir-hlo-opt/mlir-hlo-opt.cc -index f018cbdc..b4474850 100644 ---- a/tools/mlir-hlo-opt/mlir-hlo-opt.cc -+++ b/tools/mlir-hlo-opt/mlir-hlo-opt.cc -@@ -20,7 +20,6 @@ limitations under the License. - #include "mlir/InitAllExtensions.h" - #include "mlir/InitAllPasses.h" - #include "mlir/Tools/mlir-opt/MlirOptMain.h" --#include "shardy/dialect/sdy/ir/dialect.h" - #include "stablehlo/dialect/Register.h" - #include "stablehlo_ext/transforms/passes.h" - #include "transforms/gpu_passes.h" -@@ -41,6 +40,5 @@ int main(int argc, char** argv) { - registerAllExtensions(registry); - mhlo::registerAllMhloDialects(registry); - stablehlo::registerAllDialects(registry); -- registry.insert(); - return failed(MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry)); - } --- -2.34.1 - diff --git a/mlir/stablehlo b/mlir/stablehlo new file mode 160000 index 0000000000..f1f035fea3 --- /dev/null +++ b/mlir/stablehlo @@ -0,0 +1 @@ +Subproject commit f1f035fea33dcfdd7c471eb7f39174b344003117 diff --git a/mlir/tools/catalyst-cli/CMakeLists.txt b/mlir/tools/catalyst-cli/CMakeLists.txt index 7e928f3c99..d8f1a80744 100644 --- a/mlir/tools/catalyst-cli/CMakeLists.txt +++ b/mlir/tools/catalyst-cli/CMakeLists.txt @@ -38,10 +38,10 @@ set(LIBS mitigation-transforms MLIRIon ion-transforms - MhloRegisterDialects - StablehloRegister + #MhloRegisterDialects + #StablehloRegister MLIRCatalystTest - ${ALL_MHLO_PASSES} + #${ALL_MHLO_PASSES} ${ENZYME_LIB} CatalystCompilerDriver ) diff --git a/mlir/tools/quantum-lsp-server/CMakeLists.txt b/mlir/tools/quantum-lsp-server/CMakeLists.txt index 2bfcd7e134..6047752bce 100644 --- a/mlir/tools/quantum-lsp-server/CMakeLists.txt +++ b/mlir/tools/quantum-lsp-server/CMakeLists.txt @@ -11,7 +11,7 @@ set(LIBS MLIRMBQC MLIRMitigation MLIRIon - MhloRegisterDialects + #MhloRegisterDialects StablehloRegister ) diff --git a/mlir/tools/quantum-lsp-server/quantum-lsp-server.cpp b/mlir/tools/quantum-lsp-server/quantum-lsp-server.cpp index 3160ea1919..a98de69942 100644 --- a/mlir/tools/quantum-lsp-server/quantum-lsp-server.cpp +++ b/mlir/tools/quantum-lsp-server/quantum-lsp-server.cpp @@ -24,7 +24,6 @@ #include "QEC/IR/QECDialect.h" #include "Quantum/IR/QuantumDialect.h" -#include "mhlo/IR/register.h" #include "stablehlo/dialect/Register.h" int main(int argc, char **argv) @@ -39,7 +38,6 @@ int main(int argc, char **argv) registry.insert(); registry.insert(); - mlir::mhlo::registerAllMhloDialects(registry); mlir::stablehlo::registerAllDialects(registry); return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry)); diff --git a/mlir/tools/quantum-opt/CMakeLists.txt b/mlir/tools/quantum-opt/CMakeLists.txt index a0c64ce9ea..11a50462ef 100644 --- a/mlir/tools/quantum-opt/CMakeLists.txt +++ b/mlir/tools/quantum-opt/CMakeLists.txt @@ -20,11 +20,13 @@ set(LIBS mitigation-transforms MLIRIon ion-transforms - MhloRegisterDialects + #MhloRegisterDialects StablehloRegister + StablehloPasses + StablehloOps MLIRCatalystTest MLIRTestDialect - ${ALL_MHLO_PASSES} + #${ALL_MHLO_PASSES} ) add_mlir_tool(quantum-opt quantum-opt.cpp DEPENDS ${LIBS} SUPPORT_PLUGINS) diff --git a/mlir/tools/quantum-opt/quantum-opt.cpp b/mlir/tools/quantum-opt/quantum-opt.cpp index 7bcf8702bb..b314377d0f 100644 --- a/mlir/tools/quantum-opt/quantum-opt.cpp +++ b/mlir/tools/quantum-opt/quantum-opt.cpp @@ -12,17 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "mhlo/IR/register.h" -#include "mhlo/transforms/passes.h" + #include "mlir/Dialect/Func/Extensions/AllExtensions.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" -#include "stablehlo/dialect/Register.h" -#include "mhlo/IR/hlo_ops.h" +#include "stablehlo/dialect/Register.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" +#include "stablehlo/integrations/c/StablehloPasses.h" #include "Catalyst/IR/CatalystDialect.h" #include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h" @@ -47,12 +48,13 @@ int main(int argc, char **argv) { mlir::registerAllPasses(); catalyst::registerAllCatalystPasses(); - mlir::mhlo::registerAllMhloPasses(); + //mlir::mhlo::registerAllMhloPasses(); + mlirRegisterAllStablehloPasses(); mlir::DialectRegistry registry; mlir::registerAllDialects(registry); test::registerTestDialect(registry); - mlir::mhlo::registerAllMhloDialects(registry); + //mlir::mhlo::registerAllMhloDialects(registry); mlir::stablehlo::registerAllDialects(registry); mlir::func::registerAllExtensions(registry); registry.insert(); @@ -62,7 +64,8 @@ int main(int argc, char **argv) registry.insert(); registry.insert(); registry.insert(); - registry.insert(); + //registry.insert(); + registry.insert(); catalyst::registerBufferizableOpInterfaceExternalModels(registry); catalyst::gradient::registerBufferizableOpInterfaceExternalModels(registry); From 382efcb9a9c34e79f43291361b5978dad24c0a2c Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 18 Jul 2025 16:41:15 -0400 Subject: [PATCH 02/63] . --- .gitmodules | 3 --- 1 file changed, 3 deletions(-) diff --git a/.gitmodules b/.gitmodules index 6dc71ddc43..9fa2499eb3 100644 --- a/.gitmodules +++ b/.gitmodules @@ -13,6 +13,3 @@ url = https://github.com/EnzymeAD/Enzyme.git shallow = true ignore = dirty -[submodule "mlir/stablehlo"] - path = mlir/stablehlo - url = https://github.com/openxla/stablehlo.git From 59f4e93e419ec5cb298a1ac059e9f32815a76002 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 18 Jul 2025 16:46:37 -0400 Subject: [PATCH 03/63] no need for an individual `make stablehlo` --- mlir/Makefile | 37 +++---------------------------------- 1 file changed, 3 insertions(+), 34 deletions(-) diff --git a/mlir/Makefile b/mlir/Makefile index d482d0300f..d2f47419af 100644 --- a/mlir/Makefile +++ b/mlir/Makefile @@ -7,7 +7,6 @@ MK_ABSPATH := $(abspath $(lastword $(MAKEFILE_LIST))) MK_DIR := $(dir $(MK_ABSPATH)) DIALECTS_BUILD_DIR ?= $(MK_DIR)/build LLVM_BUILD_DIR ?= $(MK_DIR)/llvm-project/build -STABLEHLO_BUILD_DIR ?= $(MK_DIR)/stablehlo/build ENZYME_BUILD_DIR ?= $(MK_DIR)/Enzyme/build RT_BUILD_DIR ?= $(MK_DIR)/../runtime/build ENABLE_ASAN ?= OFF @@ -43,7 +42,6 @@ help: @echo "Please use \`make ' where is one of" @echo " all to build MLIR, MLIR-HLO and custom Catalyst dialects" @echo " llvm to build MLIR enabling Python bindings" - @echo " stablehlo to build stablehlo" @echo " enzyme to build Enzyme" @echo " dialects to build custom Catalyst MLIR dialects" @echo " test to run the Catalyst MLIR dialects test suite" @@ -52,7 +50,7 @@ help: @echo " format [version=?] to apply C++ formatter; use with 'version={version}' to run clang-format-{version} instead of clang-format" .PHONY: all -all: llvm stablehlo enzyme dialects plugin +all: llvm enzyme dialects plugin .PHONY: llvm llvm: @@ -84,29 +82,6 @@ llvm: LIT_FILTER_OUT="Bytecode|tosa-to-tensor|execution_engine" cmake --build $(LLVM_BUILD_DIR) --target $(LLVM_TARGETS) -.PHONY: stablehlo -stablehlo: - @echo "build stablehlo" - - cmake -G Ninja -S stablehlo -B $(STABLEHLO_BUILD_DIR) \ - -DSTABLEHLO_ENABLE_LLD=ON \ - -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \ - -DMLIR_DIR=$(LLVM_BUILD_DIR)/lib/cmake/mlir \ - -DLLVM_ENABLE_ASSERTIONS=ON \ - -DSTABLEHLO_ENABLE_BINDINGS_PYTHON=OFF \ - -DSTABLEHLO_ENABLE_SPLIT_DWARF=ON \ - -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ - -DSTABLEHLO_ENABLE_SANITIZER=address \ - -DCMAKE_C_COMPILER=$(C_COMPILER) \ - -DCMAKE_CXX_COMPILER=$(CXX_COMPILER) \ - -DCMAKE_C_COMPILER_LAUNCHER=$(COMPILER_LAUNCHER) \ - -DCMAKE_CXX_COMPILER_LAUNCHER=$(COMPILER_LAUNCHER) \ - -DCMAKE_EXE_LINKER_FLAGS=$(USE_SANITIZER_FLAGS) \ - -DCMAKE_CXX_VISIBILITY_PRESET=$(SYMBOL_VISIBILITY) - - cmake --build $(STABLEHLO_BUILD_DIR) - #ninja check-stablehlo-tests - .PHONY: enzyme enzyme: TARGET_FILE := $(MK_DIR)/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp enzyme: PATCH_FILE := $(MK_DIR)/patches/enzyme-nvvm-fabs-intrinsics.patch @@ -165,8 +140,6 @@ dialects: -DEnzyme_DIR=$(ENZYME_BUILD_DIR) \ -DENZYME_SRC_DIR=$(MK_DIR)/Enzyme \ -DMLIR_DIR=$(LLVM_BUILD_DIR)/lib/cmake/mlir \ - -DSTABLEHLO_SRC_DIR=$(MK_DIR)/stablehlo \ - -DSTABLEHLO_BINARY_DIR=$(STABLEHLO_BUILD_DIR)/bin \ -DRUNTIME_LIB_DIR=$(RT_BUILD_DIR)/lib \ -DMLIR_LIB_DIR=$(LLVM_BUILD_DIR)/lib \ -DCMAKE_C_COMPILER=$(C_COMPILER) \ @@ -186,8 +159,8 @@ test: @echo "test the Catalyst MLIR dialects test suite" cmake --build $(DIALECTS_BUILD_DIR) --target check-dialects -.PHONY: clean clean-dialects clean-enzyme clean-stablehlo clean-plugin -clean: clean-dialects clean-llvm clean-stablehlo clean-enzyme clean-plugin +.PHONY: clean clean-dialects clean-enzyme clean-plugin +clean: clean-dialects clean-llvm clean-enzyme clean-plugin clean-dialects: @echo "clean catalyst dialect build files" @@ -197,10 +170,6 @@ clean-llvm: @echo "clean llvm/mlir build files" rm -rf $(LLVM_BUILD_DIR) -clean-stablehlo: - @echo "clean stablehlo dialect build files" - rm -rf $(STABLEHLO_BUILD_DIR) - clean-enzyme: @echo "clean enzyme build files" rm -rf $(ENZYME_BUILD_DIR) From 730c08204348f8deaa42e3c4007326dc0ef99ae0 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Mon, 21 Jul 2025 10:33:59 -0400 Subject: [PATCH 04/63] `make dialects` can build stablehlo as embedded as of this commit :yay: --- mlir/CMakeLists.txt | 30 ++----------------- mlir/lib/Catalyst/Transforms/CMakeLists.txt | 4 +-- .../Catalyst/Transforms/scatter_lowering.cpp | 2 -- mlir/lib/Driver/CMakeLists.txt | 4 +-- mlir/tools/catalyst-cli/CMakeLists.txt | 4 +-- mlir/tools/quantum-lsp-server/CMakeLists.txt | 1 - mlir/tools/quantum-opt/CMakeLists.txt | 3 +- mlir/tools/quantum-opt/quantum-opt.cpp | 3 -- 8 files changed, 7 insertions(+), 44 deletions(-) diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index e305486d23..1fd05b4bbb 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -4,7 +4,8 @@ project(Catalyst LANGUAGES CXX C) set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -set(CMAKE_CXX_STANDARD 20 CACHE STRING "C++ standard to conform to") +# stablehlo is still on cpp17 +set(CMAKE_CXX_STANDARD 17 CACHE STRING "C++ standard to conform to") set(CMAKE_CXX_STANDARD_REQUIRED ON) # Required so as not to always use the cached option from the mlir build. @@ -29,16 +30,9 @@ endif() ######################### find_package(MLIR REQUIRED CONFIG) -#find_package(STABLEHLO REQUIRED CONFIG) -# add_subdirectory(llvm-project/mlir/cmake/modules) - - -message("hi, stable hlo src dir: ${STABLEHLO_SRC_DIR}") -#include_directories(PUBLIC ${STABLEHLO_SRC_DIR}) message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") -#message(STATUS "Using STABLEHLOConfig.cmake in: ${STABLEHLO_DIR}") set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin) set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib) @@ -46,22 +40,9 @@ set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) # Taken from mlir-hlo/mhlo/transforms/CMakeLists.txt. # Unfortunately, AllMhloPasses doesn't appear to be exported. -set(ALL_MHLO_PASSES - # ChloPasses - # MhloPasses - # StablehloPasses - # MhloToArithmeticConversion - # MhloToMemrefConversion - # HloToLinalgUtils - # MhloToLinalg - # MhloToStablehlo - # StablehloToMhlo - # StablehloPasses -) list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") -# list(APPEND CMAKE_MODULE_PATH "${STABLEHLO_CMAKE_DIR}") list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/modules") # Policy CMP0175 was introduced in CMake 4.31 and raises warnings in the upstream CMake modules. @@ -91,9 +72,6 @@ if(QUANTUM_ENABLE_BINDINGS_PYTHON) mlir_configure_python_dev_packages() endif() -# list(GET STABLEHLO_INCLUDE_DIRS 1 MLIRHLO_DIR) -# list(GET STABLEHLO_INCLUDE_DIRS 2 MLIRHLO_BUILD_DIR) - set(CATALYST_MAIN_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/include) set(CATALYST_GEN_INCLUDE_DIR ${PROJECT_BINARY_DIR}/include) set(CATALYST_LIB_DIR ${PROJECT_BINARY_DIR}) @@ -101,9 +79,6 @@ set(CATALYST_LIB_DIR ${PROJECT_BINARY_DIR}) include_directories(SYSTEM ${LLVM_INCLUDE_DIRS} ${MLIR_INCLUDE_DIRS} - #${STABLEHLO_INCLUDE_DIRS} - #${MLIRHLO_DIR}/stablehlo - #${MLIRHLO_BUILD_DIR}/stablehlo ) link_directories(${LLVM_BUILD_LIBRARY_DIR}) add_definitions(${LLVM_DEFINITIONS}) @@ -219,4 +194,5 @@ add_subdirectory(cmake/modules) add_subdirectory(test) unset(LLVM_USE_LINKER) +set(STABLEHLO_BUILD_EMBEDDED ON) add_subdirectory(stablehlo) diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index bdffe06b24..7c4809bc11 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -41,9 +41,7 @@ set(LIBS set(DEPENDS MLIRCatalystPassIncGen StablehloBaseIncGen - #StablehloBaseIncGen - # StablehloPasses - # StablehloOps + StablehloOpsIncGen ) add_mlir_library(${LIBRARY_NAME} STATIC ${SRC} LINK_LIBS PRIVATE ${LIBS} DEPENDS ${DEPENDS}) diff --git a/mlir/lib/Catalyst/Transforms/scatter_lowering.cpp b/mlir/lib/Catalyst/Transforms/scatter_lowering.cpp index 2a95de11c9..8e5cca4989 100644 --- a/mlir/lib/Catalyst/Transforms/scatter_lowering.cpp +++ b/mlir/lib/Catalyst/Transforms/scatter_lowering.cpp @@ -18,8 +18,6 @@ #include "llvm/Support/Debug.h" -// #include "mhlo/IR/hlo_ops.h" -// #include "mhlo/transforms/passes.h" #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/transforms/Passes.h" diff --git a/mlir/lib/Driver/CMakeLists.txt b/mlir/lib/Driver/CMakeLists.txt index 746ddfdcb7..b979922561 100644 --- a/mlir/lib/Driver/CMakeLists.txt +++ b/mlir/lib/Driver/CMakeLists.txt @@ -40,11 +40,10 @@ set(LIBS mitigation-transforms MLIRIon ion-transforms - #MhloRegisterDialects StablehloRegister MLIRCatalystTest - #${ALL_MHLO_PASSES} ${ENZYME_LIB} + StablehloCAPI ) add_mlir_library(CatalystCompilerDriver @@ -57,5 +56,4 @@ add_mlir_library(CatalystCompilerDriver LINK_LIBS PRIVATE ${LIBS} - #StablehloBase ) diff --git a/mlir/tools/catalyst-cli/CMakeLists.txt b/mlir/tools/catalyst-cli/CMakeLists.txt index d8f1a80744..0a23884594 100644 --- a/mlir/tools/catalyst-cli/CMakeLists.txt +++ b/mlir/tools/catalyst-cli/CMakeLists.txt @@ -38,10 +38,8 @@ set(LIBS mitigation-transforms MLIRIon ion-transforms - #MhloRegisterDialects - #StablehloRegister + StablehloRegister MLIRCatalystTest - #${ALL_MHLO_PASSES} ${ENZYME_LIB} CatalystCompilerDriver ) diff --git a/mlir/tools/quantum-lsp-server/CMakeLists.txt b/mlir/tools/quantum-lsp-server/CMakeLists.txt index 6047752bce..d0438c26c1 100644 --- a/mlir/tools/quantum-lsp-server/CMakeLists.txt +++ b/mlir/tools/quantum-lsp-server/CMakeLists.txt @@ -11,7 +11,6 @@ set(LIBS MLIRMBQC MLIRMitigation MLIRIon - #MhloRegisterDialects StablehloRegister ) diff --git a/mlir/tools/quantum-opt/CMakeLists.txt b/mlir/tools/quantum-opt/CMakeLists.txt index 11a50462ef..1d3eea1564 100644 --- a/mlir/tools/quantum-opt/CMakeLists.txt +++ b/mlir/tools/quantum-opt/CMakeLists.txt @@ -20,13 +20,12 @@ set(LIBS mitigation-transforms MLIRIon ion-transforms - #MhloRegisterDialects StablehloRegister StablehloPasses StablehloOps + StablehloCAPI MLIRCatalystTest MLIRTestDialect - #${ALL_MHLO_PASSES} ) add_mlir_tool(quantum-opt quantum-opt.cpp DEPENDS ${LIBS} SUPPORT_PLUGINS) diff --git a/mlir/tools/quantum-opt/quantum-opt.cpp b/mlir/tools/quantum-opt/quantum-opt.cpp index b314377d0f..8b526e363d 100644 --- a/mlir/tools/quantum-opt/quantum-opt.cpp +++ b/mlir/tools/quantum-opt/quantum-opt.cpp @@ -48,13 +48,11 @@ int main(int argc, char **argv) { mlir::registerAllPasses(); catalyst::registerAllCatalystPasses(); - //mlir::mhlo::registerAllMhloPasses(); mlirRegisterAllStablehloPasses(); mlir::DialectRegistry registry; mlir::registerAllDialects(registry); test::registerTestDialect(registry); - //mlir::mhlo::registerAllMhloDialects(registry); mlir::stablehlo::registerAllDialects(registry); mlir::func::registerAllExtensions(registry); registry.insert(); @@ -64,7 +62,6 @@ int main(int argc, char **argv) registry.insert(); registry.insert(); registry.insert(); - //registry.insert(); registry.insert(); catalyst::registerBufferizableOpInterfaceExternalModels(registry); From c251be7908f12af3fb408df47b397c32133a1b86 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Mon, 21 Jul 2025 10:41:11 -0400 Subject: [PATCH 05/63] hlo custom call lit test update --- mlir/test/Catalyst/HloCustomCallsTest.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Catalyst/HloCustomCallsTest.mlir b/mlir/test/Catalyst/HloCustomCallsTest.mlir index 019dc7a279..4804478574 100644 --- a/mlir/test/Catalyst/HloCustomCallsTest.mlir +++ b/mlir/test/Catalyst/HloCustomCallsTest.mlir @@ -19,6 +19,6 @@ func.func @custom_call(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { // CHECK: %cst_0 = arith.constant dense<3> : tensor // CHECK: %0 = catalyst.custom_call fn("lapack_dgesdd_ffi") (%cst, %cst_0, %cst_0, %arg0) : (tensor, tensor, tensor, tensor<3x3xf64>) -> tensor<3x3xf64> // CHECK: return %0 : tensor<3x3xf64> - %0 = mhlo.custom_call @lapack_dgesdd_ffi(%arg0) {api_version = 2 : i32, backend_config = "", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#mhlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xf64>) -> tensor<3x3xf64> + %0 = stablehlo.custom_call @lapack_dgesdd_ffi(%arg0) {api_version = 2 : i32, backend_config = "", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xf64>) -> tensor<3x3xf64> return %0 : tensor<3x3xf64> } From 84ed325fccafcfdadde261dd0a4e25383c21f512 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Mon, 21 Jul 2025 10:42:33 -0400 Subject: [PATCH 06/63] scatter lit test update --- mlir/test/Catalyst/ScatterTest.mlir | 68 ++++++++++++++--------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/mlir/test/Catalyst/ScatterTest.mlir b/mlir/test/Catalyst/ScatterTest.mlir index 9ce56aa949..8266678880 100644 --- a/mlir/test/Catalyst/ScatterTest.mlir +++ b/mlir/test/Catalyst/ScatterTest.mlir @@ -26,14 +26,14 @@ func.func public @scatter_multiply(%arg0: tensor<3xf64>, %arg1: tensor) -> %2 = arith.select %0, %1, %extracted_1 : i64 %3 = arith.trunci %2 : i64 to i32 %from_elements = tensor.from_elements %3 : tensor<1xi32> - %4 = "mhlo.scatter"(%arg0, %from_elements, %cst) ({ + %4 = "stablehlo.scatter"(%arg0, %from_elements, %cst) ({ ^bb0(%arg2: tensor, %arg3: tensor): %extracted_2 = tensor.extract %arg2[] : tensor %extracted_3 = tensor.extract %arg3[] : tensor %5 = arith.mulf %extracted_2, %extracted_3 : f64 %from_elements_4 = tensor.from_elements %5 : tensor - mhlo.return %from_elements_4 : tensor - }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter, unique_indices = true} : (tensor<3xf64>, tensor<1xi32>, tensor) -> tensor<3xf64> + stablehlo.return %from_elements_4 : tensor + }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xf64>, tensor<1xi32>, tensor) -> tensor<3xf64> return %4 : tensor<3xf64> } @@ -74,14 +74,14 @@ func.func public @two_scatter(%arg0: tensor<3xf64>, %arg1: tensor) -> tenso %2 = arith.select %0, %1, %extracted_2 : i64 %3 = arith.trunci %2 : i64 to i32 %from_elements = tensor.from_elements %3 : tensor<1xi32> - %4 = "mhlo.scatter"(%arg0, %from_elements, %cst_0) ({ + %4 = "stablehlo.scatter"(%arg0, %from_elements, %cst_0) ({ ^bb0(%arg2: tensor, %arg3: tensor): %extracted_7 = tensor.extract %arg2[] : tensor %extracted_8 = tensor.extract %arg3[] : tensor %12 = arith.mulf %extracted_7, %extracted_8 : f64 %from_elements_9 = tensor.from_elements %12 : tensor - mhlo.return %from_elements_9 : tensor - }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter, unique_indices = true} : (tensor<3xf64>, tensor<1xi32>, tensor) -> tensor<3xf64> + stablehlo.return %from_elements_9 : tensor + }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xf64>, tensor<1xi32>, tensor) -> tensor<3xf64> %extracted_3 = tensor.extract %arg1[] : tensor %5 = arith.cmpi slt, %extracted_3, %c0_i64 : i64 %extracted_4 = tensor.extract %arg1[] : tensor @@ -90,14 +90,14 @@ func.func public @two_scatter(%arg0: tensor<3xf64>, %arg1: tensor) -> tenso %7 = arith.select %5, %6, %extracted_5 : i64 %8 = arith.trunci %7 : i64 to i32 %from_elements_6 = tensor.from_elements %8 : tensor<1xi32> - %9 = "mhlo.scatter"(%arg0, %from_elements_6, %cst) ({ + %9 = "stablehlo.scatter"(%arg0, %from_elements_6, %cst) ({ ^bb0(%arg2: tensor, %arg3: tensor): %extracted_7 = tensor.extract %arg2[] : tensor %extracted_8 = tensor.extract %arg3[] : tensor %12 = arith.addf %extracted_7, %extracted_8 : f64 %from_elements_9 = tensor.from_elements %12 : tensor - mhlo.return %from_elements_9 : tensor - }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter, unique_indices = true} : (tensor<3xf64>, tensor<1xi32>, tensor) -> tensor<3xf64> + stablehlo.return %from_elements_9 : tensor + }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xf64>, tensor<1xi32>, tensor) -> tensor<3xf64> %10 = tensor.empty() : tensor<3xf64> %11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%4, %9 : tensor<3xf64>, tensor<3xf64>) outs(%10 : tensor<3xf64>) { ^bb0(%in: f64, %in_7: f64, %out: f64): @@ -149,15 +149,15 @@ func.func public @two_scatter(%arg0: tensor<3xf64>, %arg1: tensor) -> tenso func.func public @full_example_scatter(%input: tensor<3x4x2xi64>, %update: tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64> attributes {llvm.emit_c_interface} { %scatter_indices = arith.constant dense<2> : tensor<2x3x2xi32> - %result = "mhlo.scatter"(%input, %scatter_indices, %update) ({ + %result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({ ^bb0(%arg2: tensor, %arg3: tensor): %extracted_1 = tensor.extract %arg2[] : tensor %extracted_2 = tensor.extract %arg3[] : tensor %1 = arith.addi %extracted_1, %extracted_2 : i64 %from_elements = tensor.from_elements %1 : tensor - mhlo.return %from_elements : tensor + stablehlo.return %from_elements : tensor }) { - scatter_dimension_numbers = #mhlo.scatter< + scatter_dimension_numbers = #stablehlo.scatter< update_window_dims = [2, 3], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0, 1], @@ -205,7 +205,7 @@ func.func public @example_no_update_dim(%arg0: tensor<4xf64>) -> tensor<4xf64> { ^bb0(%in: i32, %out: i32): linalg.yield %in : i32 } -> tensor<2x1xi32> - %2 = "mhlo.scatter"(%cst, %1, %cst_0) ({ + %2 = "stablehlo.scatter"(%cst, %1, %cst_0) ({ ^bb0(%arg1: tensor, %arg2: tensor): %3 = tensor.empty() : tensor %4 = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = []} ins(%arg1, %arg2 : tensor, tensor) outs(%3 : tensor) { @@ -213,8 +213,8 @@ func.func public @example_no_update_dim(%arg0: tensor<4xf64>) -> tensor<4xf64> { %5 = arith.addf %in, %in_2 : f64 linalg.yield %5 : f64 } -> tensor - mhlo.return %4 : tensor - }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter, unique_indices = true} : (tensor<4xf64>, tensor<2x1xi32>, tensor<2xf64>) -> tensor<4xf64> + stablehlo.return %4 : tensor + }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<4xf64>, tensor<2x1xi32>, tensor<2xf64>) -> tensor<4xf64> return %2 : tensor<4xf64> } @@ -252,7 +252,7 @@ func.func public @example_no_update_dim(%arg0: tensor<4xf64>) -> tensor<4xf64> { // CHECK-LABEL: @test_happy_path module @test_happy_path { - // CHECK-NOT: mhlo.scatter + // CHECK-NOT: stablehlo.scatter // CHECK-DAG: [[cst0:%.+]] = index.constant 0 // CHECK-DAG: [[inputs:%.+]] = "test.op"() : () -> tensor<[[dim1:.*]]x[[dim0:.*]]xf64> // CHECK-DAG: [[scatter_indices:%.+]] = "test.op"() : () -> tensor<1xi32> @@ -263,17 +263,17 @@ module @test_happy_path { %inputs = "test.op"() : () -> (tensor<7x5xf64>) %scatter_indices = "test.op"() : () -> (tensor<1xi32>) %updates = "test.op"() : () -> (tensor<5xf64>) - %results = "mhlo.scatter"(%inputs, %scatter_indices, %updates) <{ + %results = "stablehlo.scatter"(%inputs, %scatter_indices, %updates) <{ indices_are_sorted = true, unique_indices = true, - scatter_dimension_numbers = #mhlo.scatter< + scatter_dimension_numbers = #stablehlo.scatter< update_window_dims = [0], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0] > }> ({ ^bb0(%arg3: tensor, %arg4: tensor): - mhlo.return %arg4 : tensor + stablehlo.return %arg4 : tensor }) : (tensor<7x5xf64>, tensor<1xi32>, tensor<5xf64>) -> tensor<7x5xf64> "test.op"(%results) : (tensor<7x5xf64>) -> () @@ -286,17 +286,17 @@ module @test_multiple_inputs { %scatter_indices = "test.op"() : () -> (tensor<1xi32>) %updates = "test.op"() : () -> (tensor<5xf64>) // expected-error@+1 {{Only one input, update, and result}} - %results:2 = "mhlo.scatter"(%inputs, %inputs, %scatter_indices, %updates, %updates) <{ + %results:2 = "stablehlo.scatter"(%inputs, %inputs, %scatter_indices, %updates, %updates) <{ indices_are_sorted = true, unique_indices = true, - scatter_dimension_numbers = #mhlo.scatter< + scatter_dimension_numbers = #stablehlo.scatter< update_window_dims = [0], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0] > }> ({ ^bb0(%arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor): - mhlo.return %arg4, %arg6 : tensor, tensor + stablehlo.return %arg4, %arg6 : tensor, tensor }) : (tensor<7x5xf64>, tensor<7x5xf64>, tensor<1xi32>, tensor<5xf64>, tensor<5xf64>) -> (tensor<7x5xf64>, tensor<7x5xf64>) "test.op"(%results#0, %results#1) : (tensor<7x5xf64>, tensor<7x5xf64>) -> () } @@ -312,10 +312,10 @@ module @test_is_not_assignment { // CHECK-NOT: tensor.insert_slice - %results = "mhlo.scatter"(%inputs, %scatter_indices, %updates) <{ + %results = "stablehlo.scatter"(%inputs, %scatter_indices, %updates) <{ indices_are_sorted = true, unique_indices = true, - scatter_dimension_numbers = #mhlo.scatter< + scatter_dimension_numbers = #stablehlo.scatter< update_window_dims = [0], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0] @@ -323,7 +323,7 @@ module @test_is_not_assignment { }> ({ ^bb0(%arg3: tensor, %arg4: tensor): %add = stablehlo.add %arg3, %arg4 : tensor - mhlo.return %add : tensor + stablehlo.return %add : tensor }) : (tensor<7x5xf64>, tensor<1xi32>, tensor<5xf64>) -> tensor<7x5xf64> "test.op"(%results) : (tensor<7x5xf64>) -> () } @@ -345,17 +345,17 @@ module @insert_tensor_rank_2 { // CHECK: [[idx:%.+]] = arith.index_cast [[scatter_idx]] : i32 to index // CHECK: tensor.insert_slice [[updates]] into [[inputs]][[[idx]], 0, 0] [1, [[dim1]], [[dim0]]] [1, 1, 1] : tensor<7x5xf64> into tensor<9x7x5xf64> - %results = "mhlo.scatter"(%inputs, %scatter_indices, %updates) <{ + %results = "stablehlo.scatter"(%inputs, %scatter_indices, %updates) <{ indices_are_sorted = true, unique_indices = true, - scatter_dimension_numbers = #mhlo.scatter< + scatter_dimension_numbers = #stablehlo.scatter< update_window_dims = [0, 1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0] > }> ({ ^bb0(%arg3: tensor, %arg4: tensor): - mhlo.return %arg4 : tensor + stablehlo.return %arg4 : tensor }) : (tensor<9x7x5xf64>, tensor<1xi32>, tensor<7x5xf64>) -> tensor<9x7x5xf64> "test.op"(%results) : (tensor<9x7x5xf64>) -> () } @@ -380,17 +380,17 @@ module @two_dyn_indices { // CHECK-DAG: [[idx__1:%.+]] = arith.index_cast [[scatter_idx_1]] : i32 to index // CHECK: tensor.insert_slice [[updates]] into [[inputs]][[[idx__0]], [[idx__1]], 0] [1, 1, [[dim0]]] [1, 1, 1] : tensor<[[dim0]]xf64> into tensor<9x7x5xf64> - %results = "mhlo.scatter"(%inputs, %scatter_indices, %updates) <{ + %results = "stablehlo.scatter"(%inputs, %scatter_indices, %updates) <{ indices_are_sorted = true, unique_indices = true, - scatter_dimension_numbers = #mhlo.scatter< + scatter_dimension_numbers = #stablehlo.scatter< update_window_dims = [0], inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1] > }> ({ ^bb0(%arg3: tensor, %arg4: tensor): - mhlo.return %arg4 : tensor + stablehlo.return %arg4 : tensor }) : (tensor<9x7x5xf64>, tensor<2xi32>, tensor<5xf64>) -> tensor<9x7x5xf64> "test.op"(%results) : (tensor<9x7x5xf64>) -> () } @@ -414,17 +414,17 @@ module @two_dyn_indices_reverted { // CHECK-DAG: [[idx__0:%.+]] = arith.index_cast [[scatter_idx_0]] : i32 to index // CHECK-DAG: [[idx__1:%.+]] = arith.index_cast [[scatter_idx_1]] : i32 to index - %results = "mhlo.scatter"(%inputs, %scatter_indices, %updates) <{ + %results = "stablehlo.scatter"(%inputs, %scatter_indices, %updates) <{ indices_are_sorted = true, unique_indices = true, - scatter_dimension_numbers = #mhlo.scatter< + scatter_dimension_numbers = #stablehlo.scatter< update_window_dims = [0], inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [1, 0] // This line is changed > }> ({ ^bb0(%arg3: tensor, %arg4: tensor): - mhlo.return %arg4 : tensor + stablehlo.return %arg4 : tensor }) : (tensor<9x7x5xf64>, tensor<2xi32>, tensor<5xf64>) -> tensor<9x7x5xf64> "test.op"(%results) : (tensor<9x7x5xf64>) -> () From 163b0f5051987ac9e7fe76557f10109e158c7e6e Mon Sep 17 00:00:00 2001 From: paul0403 Date: Mon, 21 Jul 2025 10:53:29 -0400 Subject: [PATCH 07/63] remove mhlo steps in main CI script --- .github/workflows/check-catalyst.yaml | 95 +-------------------------- 1 file changed, 1 insertion(+), 94 deletions(-) diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index c612b488cc..fef9e4ab45 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -172,81 +172,6 @@ jobs: COMPILER_LAUNCHER="" \ make llvm - mhlo: - name: MHLO Dialect Build - needs: [constants, llvm, determine_runner] - runs-on: ${{ needs.determine_runner.outputs.runner_group }} - strategy: - matrix: - compiler: ${{ fromJson(needs.constants.outputs.compilers) }} - - steps: - - name: Checkout Catalyst repo - uses: actions/checkout@v4 - - - name: Set up Python # Ensure the "primary" python version is used - uses: actions/setup-python@v5 - with: - python-version: ${{ needs.constants.outputs.primary_python_version }} - - - name: Cache MHLO Source - id: cache-mhlo-source - uses: actions/cache@v4 - with: - path: mlir/mlir-hlo - key: mhlo-${{ needs.constants.outputs.mhlo_version }}-default-source - enableCrossOsArchive: true - - - name: Clone MHLO Submodule - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' - uses: actions/checkout@v4 - with: - repository: tensorflow/mlir-hlo - ref: ${{ needs.constants.outputs.mhlo_version }} - path: mlir/mlir-hlo - - - name: Cache MHLO Build - id: cache-mhlo - uses: actions/cache@v4 - with: - path: mhlo-build - key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-ci-build-${{ matrix.compiler }}-0 - - - name: Get Cached LLVM Source - id: cache-llvm-source - if: steps.cache-mhlo.outputs.cache-hit != 'true' - uses: actions/cache@v4 - with: - path: mlir/llvm-project - key: llvm-${{ needs.constants.outputs.llvm_version }}-default-source - enableCrossOsArchive: true - fail-on-cache-miss: true - - - name: Get Cached LLVM Build - id: cache-llvm-build - if: steps.cache-mhlo.outputs.cache-hit != 'true' - uses: actions/cache@v4 - with: - path: llvm-build - key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-ci-build-${{ matrix.compiler }} - fail-on-cache-miss: true - - - name: Install Deps - if: steps.cache-mhlo.outputs.cache-hit != 'true' - run: | - sudo apt-get update - sudo apt-get install -y cmake ninja-build clang lld - - - name: Build MHLO Dialect - if: steps.cache-mhlo.outputs.cache-hit != 'true' - run: | - C_COMPILER=$(which ${{ needs.constants.outputs[format('c_compiler.{0}', matrix.compiler)] }}) \ - CXX_COMPILER=$(which ${{ needs.constants.outputs[format('cxx_compiler.{0}', matrix.compiler)] }}) \ - LLVM_BUILD_DIR="$(pwd)/llvm-build" \ - MHLO_BUILD_DIR="$(pwd)/mhlo-build" \ - COMPILER_LAUNCHER="" \ - make mhlo - enzyme: name: Enzyme Build needs: [constants, llvm, determine_runner] @@ -324,7 +249,7 @@ jobs: quantum: name: Quantum Dialects Build - needs: [constants, llvm, mhlo, enzyme, determine_runner] + needs: [constants, llvm, enzyme, determine_runner] runs-on: ${{ needs.determine_runner.outputs.runner_group }} strategy: matrix: @@ -363,23 +288,6 @@ jobs: key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-ci-build-${{ matrix.compiler }} fail-on-cache-miss: true - - name: Get Cached MHLO Source - id: cache-mhlo-source - uses: actions/cache/restore@v4 - with: - path: mlir/mlir-hlo - key: mhlo-${{ needs.constants.outputs.mhlo_version }}-default-source - enableCrossOsArchive: true - fail-on-cache-miss: true - - - name: Get Cached MHLO Build - id: cache-mhlo - uses: actions/cache/restore@v4 - with: - path: mhlo-build - key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-ci-build-${{ matrix.compiler }}-0 - fail-on-cache-miss: true - - name: Get Cached Enzyme Source id: cache-enzyme-source uses: actions/cache/restore@v4 @@ -412,7 +320,6 @@ jobs: C_COMPILER=$(which ${{ needs.constants.outputs[format('c_compiler.{0}', matrix.compiler)] }}) \ CXX_COMPILER=$(which ${{ needs.constants.outputs[format('cxx_compiler.{0}', matrix.compiler)] }}) \ LLVM_BUILD_DIR="$(pwd)/llvm-build" \ - MHLO_BUILD_DIR="$(pwd)/mhlo-build" \ ENZYME_BUILD_DIR="$(pwd)/enzyme-build" \ DIALECTS_BUILD_DIR="$(pwd)/quantum-build" \ make dialects From 17a109e1a7de9dae364beed83d275cd1b89080ea Mon Sep 17 00:00:00 2001 From: paul0403 Date: Mon, 21 Jul 2025 11:20:22 -0400 Subject: [PATCH 08/63] update py pipeline names --- frontend/catalyst/pipelines.py | 14 +++++++------- mlir/lib/Driver/Pipelines.cpp | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index 45398adfab..38754a288b 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -223,13 +223,13 @@ def get_hlo_lowering_stage(_options: CompileOptions) -> List[str]: """Returns the list of passes to lower StableHLO to upstream MLIR dialects.""" hlo_lowering = [ "canonicalize", - "func.func(chlo-legalize-to-hlo)", - "stablehlo-legalize-to-hlo", - "func.func(mhlo-legalize-control-flow)", - "func.func(hlo-legalize-to-linalg)", - "func.func(mhlo-legalize-to-std)", - "func.func(hlo-legalize-sort)", - "convert-to-signless", + "func.func(chlo-legalize-to-stablehlo)", + #"stablehlo-legalize-to-hlo", + #"func.func(mhlo-legalize-control-flow)", + "func.func(stablehlo-legalize-to-linalg)", + #"func.func(mhlo-legalize-to-std)", + #"func.func(hlo-legalize-sort)", + "stablehlo-convert-to-signless", "canonicalize", "scatter-lowering", "hlo-custom-call-lowering", diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index 35f74a1e1e..383347bcef 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -40,13 +40,13 @@ void createEnforceRuntimeInvariantsPipeline(OpPassManager &pm) void createHloLoweringPipeline(OpPassManager &pm) { pm.addPass(mlir::createCanonicalizerPass()); - //pm.addNestedPass(stablehlo::createChloLegalizeToStablehloPass()); + pm.addNestedPass(stablehlo::createChloLegalizeToStablehloPass()); //pm.addPass(stablehlo::createStablehloLegalizeToHloPass()); //pm.addNestedPass(stablehlo::createLegalizeControlFlowPass()); - //pm.addNestedPass(stablehlo::createLegalizeHloToLinalgPass()); + // (?) pm.addNestedPass(stablehlo::createStablehloLegalizeToLinalgPass()); //pm.addNestedPass(stablehlo::createLegalizeToStdPass()); //pm.addNestedPass(stablehlo::createLegalizeSortPass()); - //pm.addPass(stablehlo::createConvertToSignlessPass()); + // (!) pm.addPass(stablehlo::createConvertToSignlessPass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(catalyst::createScatterLoweringPass()); pm.addPass(catalyst::createHloCustomCallLoweringPass()); From 264265332535a1a19e450909e023eb9c21d9ebbb Mon Sep 17 00:00:00 2001 From: paul0403 Date: Mon, 21 Jul 2025 13:48:00 -0400 Subject: [PATCH 09/63] add back -legalize-to-sort pass --- frontend/catalyst/pipelines.py | 2 +- mlir/include/Catalyst/Transforms/Passes.h | 1 + mlir/include/Catalyst/Transforms/Passes.td | 18 + .../Transforms/BufferDeallocation.cpp | 8 - mlir/lib/Catalyst/Transforms/CMakeLists.txt | 1 + .../Catalyst/Transforms/RegisterAllPasses.cpp | 1 + .../Transforms/stablehlo_legalize_sort.cpp | 608 ++++++++++++++++++ mlir/lib/Driver/Pipelines.cpp | 2 +- 8 files changed, 631 insertions(+), 10 deletions(-) create mode 100644 mlir/lib/Catalyst/Transforms/stablehlo_legalize_sort.cpp diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index 38754a288b..2eea6f2edd 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -228,7 +228,7 @@ def get_hlo_lowering_stage(_options: CompileOptions) -> List[str]: #"func.func(mhlo-legalize-control-flow)", "func.func(stablehlo-legalize-to-linalg)", #"func.func(mhlo-legalize-to-std)", - #"func.func(hlo-legalize-sort)", + "func.func(stablehlo-legalize-sort)", "stablehlo-convert-to-signless", "canonicalize", "scatter-lowering", diff --git a/mlir/include/Catalyst/Transforms/Passes.h b/mlir/include/Catalyst/Transforms/Passes.h index f7872961ad..271a05c758 100644 --- a/mlir/include/Catalyst/Transforms/Passes.h +++ b/mlir/include/Catalyst/Transforms/Passes.h @@ -37,6 +37,7 @@ std::unique_ptr createQnodeToAsyncLoweringPass(); std::unique_ptr createRegisterInactiveCallbackPass(); std::unique_ptr createScatterLoweringPass(); std::unique_ptr createSplitMultipleTapesPass(); +std::unique_ptr createStablehloLegalizeSortPass(); void registerAllCatalystPasses(); diff --git a/mlir/include/Catalyst/Transforms/Passes.td b/mlir/include/Catalyst/Transforms/Passes.td index e731cecddc..14139ae503 100644 --- a/mlir/include/Catalyst/Transforms/Passes.td +++ b/mlir/include/Catalyst/Transforms/Passes.td @@ -282,4 +282,22 @@ def BufferDeallocation : Pass<"buffer-deallocation", "func::FuncOp"> { let constructor = "mlir::bufferization::createBufferDeallocationPass()"; } +// mhlo legalize sort pass. +// mhlo dropped the -legalize-sort pass when migrating to stablehlo. +// We manually add it back. +// +// This pass has been modified from its original form in the tensorflow/mlir-hlo repository at +// https://github.com/tensorflow/mlir-hlo/blob/a5529d99fc4d1132b0c282a053d26c11e6636b3a/mhlo/transforms/mhlo_passes.td +// released under the Apache License, Version 2.0, with the following copyright notice: +// +// * Licensed under the Apache License, Version 2.0 (the "License"); + +def StablehloLegalizeSortPass : Pass<"stablehlo-legalize-sort", "func::FuncOp"> { + let summary = "Legalize from Stablehlo sort to SCF control flow."; + let constructor = "createStablehloLegalizeSortPass()"; + let dependentDialects = ["arith::ArithDialect", + "bufferization::BufferizationDialect", + "scf::SCFDialect", "tensor::TensorDialect"]; +} + #endif // CATALYST_PASSES diff --git a/mlir/lib/Catalyst/Transforms/BufferDeallocation.cpp b/mlir/lib/Catalyst/Transforms/BufferDeallocation.cpp index 012799bec7..04e077f94e 100644 --- a/mlir/lib/Catalyst/Transforms/BufferDeallocation.cpp +++ b/mlir/lib/Catalyst/Transforms/BufferDeallocation.cpp @@ -89,7 +89,6 @@ #include "Catalyst/IR/CatalystDialect.h" #include "Catalyst/Transforms/Passes.h" -// using namespace llvm; using namespace mlir; using namespace catalyst; @@ -101,13 +100,6 @@ namespace catalyst { } // namespace catalyst -// namespace mlir { -// namespace bufferization { -// #define GEN_PASS_DEF_BUFFERDEALLOCATION -// #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" -// } // namespace bufferization -// } // namespace mlir - using namespace mlir; using namespace mlir::bufferization; diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index 7c4809bc11..a6313f4682 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -25,6 +25,7 @@ file(GLOB SRC scatter_lowering.cpp ScatterPatterns.cpp SplitMultipleTapes.cpp + stablehlo_legalize_sort.cpp TBAAPatterns.cpp TBAATagsPass.cpp ) diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index 2fa570fa96..74031733f7 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -67,4 +67,5 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createGatesToPulsesPass); mlir::registerPass(catalyst::createLoopBoundaryOptimizationPass); mlir::registerPass(catalyst::createMBQCConversionPass); + mlir::registerPass(catalyst::createStablehloLegalizeSortPass); } diff --git a/mlir/lib/Catalyst/Transforms/stablehlo_legalize_sort.cpp b/mlir/lib/Catalyst/Transforms/stablehlo_legalize_sort.cpp new file mode 100644 index 0000000000..e821369dda --- /dev/null +++ b/mlir/lib/Catalyst/Transforms/stablehlo_legalize_sort.cpp @@ -0,0 +1,608 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +// This file is taken from the +// tensorflow/mlir-hlo +// repository, under the Apache 2.0 License, at +// https://github.com/tensorflow/mlir-hlo/blob/a5529d99fc4d1132b0c282a053d26c11e6636b3a/mhlo/transforms/legalize_sort/legalize_sort.cc +// with the following copyright notice: + + /* Copyright 2019 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +// The modifications are porting the pass from the upstream MHLO namespace to +// catalyst namespace. + +// This file implements logic for lowering stablehlo.sort to the SCF dialect. +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" + +#include "Catalyst/IR/CatalystDialect.h" +#include "Catalyst/Transforms/Passes.h" + +using namespace mlir; +using namespace stablehlo; +using namespace catalyst; + +namespace catalyst { + +#define GEN_PASS_DEF_STABLEHLOLEGALIZESORTPASS +#define GEN_PASS_DECL_STABLEHLOLEGALIZESORTPASS +#include "Catalyst/Transforms/Passes.h.inc" + +} // namespace catalyst + +namespace { + +using ::mlir::arith::AddIOp; +using ::mlir::arith::MinSIOp; +using ::mlir::arith::SelectOp; + +constexpr int64_t kInsertionSortSize = 16; + +// Inlines the `comparator` region (without terminator) at the current insertion +// point, replacing the arguments with the given values from `lhs` and `rhs`. +Value emitComparison(ImplicitLocOpBuilder& b, SmallVector& lhs, + SmallVector& rhs, Region& comparator) { + assert(comparator.hasOneBlock() && "Comparator must have only one block."); + Block& block = comparator.front(); + assert(block.getTerminator()->getOperands().size() == 1 && + "Comparator must return a single value"); + + IRMapping mapping; + for (auto [idx, arg] : llvm::enumerate(comparator.getArguments())) { + Value value = idx % 2 == 0 ? lhs[idx / 2] : rhs[idx / 2]; + Type type = RankedTensorType::get({}, value.getType()); + mapping.map(arg, b.create(type, value)); + } + + for (Operation& op : block.without_terminator()) b.clone(op, mapping); + Value result = mapping.lookup(block.getTerminator()->getOperands().front()); + + return b.create(result, ValueRange()); +} + +// Emits a binary search of `pivots` in `arrayMemrefs` (all rank 1) in the range +// [`left`;`right`). `arrayMemrefs` must be sorted according to `comparator`. +Value emitBinarySearch(ImplicitLocOpBuilder& b, Value leftInit, Value rightInit, + SmallVector& pivots, ValueRange arrayMemrefs, + Region& comparator) { + SmallVector types{leftInit.getType(), rightInit.getType()}; + ArithBuilder arith(b, b.getLoc()); + + // while ( + auto whileOp = + b.create(types, SmallVector{leftInit, rightInit}); + OpBuilder::InsertionGuard guard(b); + + // left < right) { + Block* before = b.createBlock(&whileOp.getBefore(), {}, types, + {whileOp.getLoc(), whileOp.getLoc()}); + { + Value left = before->getArgument(0), right = before->getArgument(1); + b.setInsertionPointToEnd(before); + b.create(arith.slt(left, right), before->getArguments()); + } + + Block* after = b.createBlock(&whileOp.getAfter(), {}, types, + {whileOp.getLoc(), whileOp.getLoc()}); + { + Value left = after->getArgument(0), right = after->getArgument(1); + b.setInsertionPointToEnd(after); + // int mid = (left + right) >> 1; + Value one = b.create(1); + Value mid = b.create(arith.add(left, right), one); + Value midPlusOne = b.create(mid, one); + + auto arraysAtMid = llvm::to_vector( + llvm::map_range(arrayMemrefs, [&](Value arrayMemref) -> Value { + return b.create(arrayMemref, mid); + })); + Value cond = emitComparison(b, pivots, arraysAtMid, comparator); + // if (comparator(pivot, array[mid])) + // right = mid; + // else + // left = mid + 1; + Value newLeft = arith.select(cond, left, midPlusOne); + Value newRight = arith.select(cond, mid, right); + + // } + b.create(ValueRange{newLeft, newRight}); + } + + return whileOp.getResult(0); +} + +SmallVector loadTensorElements(ImplicitLocOpBuilder& b, + ValueRange tensors, Value index) { + return llvm::to_vector(llvm::map_range(tensors, [&](Value tensor) -> Value { + return b.create(tensor, index); + })); +} + +SmallVector loadMemrefElements(ImplicitLocOpBuilder& b, + ValueRange memrefs, Value index) { + return llvm::to_vector(llvm::map_range(memrefs, [&](Value memref) -> Value { + Type type = mlir::cast(memref.getType()).getElementType(); + return b.create(type, memref, index); + })); +} + +void storeMemrefElements(ImplicitLocOpBuilder& b, ValueRange memrefs, + Value index, ValueRange values) { + for (auto [value, memref] : llvm::zip(values, memrefs)) { + b.create(value, memref, index); + } +} + +// Insertion sorts `inputTensors` in the range [`lo`; `hi`), storing the results +// in `outputMemrefs`. `inputTensors` and `outputMemrefs` must all be rank 1 and +// of identical size. +void emitInsertionSort(ImplicitLocOpBuilder& b, Value lo, Value hi, + ValueRange inputTensors, ValueRange outputMemrefs, + mlir::Region& comparator) { + ArithBuilder arith(b, b.getLoc()); + Value zero = b.create(0); + Value one = b.create(1); + + // array[lo] = tensors[lo]; + storeMemrefElements(b, outputMemrefs, lo, + loadTensorElements(b, inputTensors, lo)); + + // for (int start = lo + 1; start < hi; ++start) + { + auto forOp = b.create(arith.add(lo, one), hi, one); + OpBuilder::InsertionGuard outerGuard(b); + b.setInsertionPointToStart(forOp.getBody()); + Value start = forOp.getInductionVar(); + + // T pivot = tensors[start]; + auto pivots = loadTensorElements(b, inputTensors, start); + + // int index = binarySearch(lo, start, pivot, array, comparator); + auto index = + emitBinarySearch(b, lo, start, pivots, outputMemrefs, comparator); + + // int n = start - index; // The number of elements to move + Value n = arith.sub(start, index); + + // memmove(&array[index + 1], &array[index], n * sizeof(T)) + // memref::CopyOp would be nice to use here, but: + // 1. It lowers to a quite inefficient library call in the general case + // (strides != 1). + // 2. It implements memcpy semantics, but we need memmove here. + // So we go with a loop instead. + auto copyForOp = b.create(zero, n, one); + { + OpBuilder::InsertionGuard innerGuard(b); + b.setInsertionPointToStart(copyForOp.getBody()); + Value copyLoopIndex = copyForOp.getBody()->getArgument(0); + + Value dstIndex = arith.sub(start, copyLoopIndex); + Value srcIndex = arith.sub(dstIndex, one); + storeMemrefElements(b, outputMemrefs, dstIndex, + loadMemrefElements(b, outputMemrefs, srcIndex)); + } + // array[index] = pivot; + storeMemrefElements(b, outputMemrefs, index, pivots); + } +} + +void emitMerge(ImplicitLocOpBuilder& b, Value lo, Value mid, Value hi, + ValueRange readBufs, ValueRange writeBufs, + mlir::Region& comparator) { + ArithBuilder arith(b, b.getLoc()); + // The while loop runs until we reach the end of either interval. It has three + // loop-carried variables: + // 1. current output index + // 2. current read index for interval 1 + // 3. current read index for interval 2 + SmallVector whileArgTypes{lo.getType(), lo.getType(), mid.getType()}; + SmallVector whileInitArgs{lo, lo, mid}; + SmallVector whileArgLocs(whileArgTypes.size(), b.getLoc()); + + // while( + auto whileOp = b.create(whileArgTypes, whileInitArgs); + { + OpBuilder::InsertionGuard guard(b); + { + Block* before = + b.createBlock(&whileOp.getBefore(), {}, whileArgTypes, whileArgLocs); + Value i0 = before->getArgument(1), i1 = before->getArgument(2); + b.setInsertionPointToEnd(before); + + // i0 < mid && i1 < hi) { + Value inbounds0 = arith.slt(i0, mid); + Value inbounds1 = arith.slt(i1, hi); + + b.create(arith._and(inbounds0, inbounds1), + before->getArguments()); + } + + { + Block* after = + b.createBlock(&whileOp.getAfter(), {}, whileArgTypes, whileArgLocs); + Value iOut = after->getArgument(0), i0 = after->getArgument(1), + i1 = after->getArgument(2); + b.setInsertionPointToEnd(after); + + // auto vals0 = readBufs[i0], vals1 = readBufs[i1]; + SmallVector vals0 = loadMemrefElements(b, readBufs, i0); + SmallVector vals1 = loadMemrefElements(b, readBufs, i1); + + // writeBufs[iOut] = comparator(vals1, vals0) + // ? readBufs[i1++] : readBufs[i0++]; + Value cmp = emitComparison(b, vals1, vals0, comparator); + SmallVector pickedVals; + for (auto [val0, val1] : llvm::zip(vals0, vals1)) { + pickedVals.push_back(b.create(cmp, val1, val0)); + } + storeMemrefElements(b, writeBufs, iOut, pickedVals); + + Value one = b.create(1); + Value nexti0 = b.create(cmp, i0, arith.add(i0, one)); + Value nexti1 = b.create(cmp, arith.add(i1, one), i1); + // ++iOut; + Value nextIOut = b.create(iOut, one); + b.create(ValueRange{nextIOut, nexti0, nexti1}); + } + } + + // At this point, exactly one of the input ranges will have leftover elements. + Value iOut = whileOp->getResult(0); + Value i0 = whileOp->getResult(1); + Value i1 = whileOp->getResult(2); + + // We could use memref::CopyOp here, but typically, there aren't many leftover + // elements for randomly shuffled inputs. + Value leftoverIn0 = arith.slt(i0, mid); + Value start = arith.select(leftoverIn0, i0, i1); + Value end = arith.select(leftoverIn0, mid, hi); + Value n = arith.sub(end, start); + + Value zero = b.create(0); + Value one = b.create(1); + auto forOp = b.create(zero, n, one); + b.setInsertionPointToStart(forOp.getBody()); + Value copyIndex = forOp.getBody()->getArgument(0); + + Value srcIndex = arith.add(start, copyIndex); + Value dstIndex = arith.add(iOut, copyIndex); + storeMemrefElements(b, writeBufs, dstIndex, + loadMemrefElements(b, readBufs, srcIndex)); +} + +// Emits a bottom up merge sort of `inputTensors` in the range [`lo`; `hi`), and +// writes the results to either `outputs0` or `outputs1`. +// Returns 0 if the results are in `outputs0`, 1 if they are in `outputs1`. +// TODO(jreiffers): Consider implementing top-down merge sort. +Value emitBottomUpMergeSort(ImplicitLocOpBuilder& b, Value lo, Value hi, + int64_t staticSortDimSize, ValueRange inputTensors, + ValueRange outputs0, ValueRange outputs1, + mlir::Region& comparator) { + ArithBuilder arith(b, b.getLoc()); + Value size = arith.sub(hi, lo); + + Value zero = b.create(0); + Value insertionSortSize = + b.create(kInsertionSortSize); + + // Run insertion sort on blocks of size kInsertionSortSize. + // for (int start = 0; start < size; start += kInsertionSortSize) { + { + auto forOp = b.create(zero, size, insertionSortSize); + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(forOp.getBody()); + Value start = forOp.getBody()->getArgument(0); + Value end = arith.add( + b.create(arith.add(start, insertionSortSize), size), lo); + emitInsertionSort(b, start, end, inputTensors, outputs0, comparator); + } + + Value initParity = b.create(0, 1); + if (staticSortDimSize >= 0 && staticSortDimSize < kInsertionSortSize) { + return initParity; + } + + // The while arguments are: + // 1. the current size + // 2. the original index of the buffers we're currently reading from + // 3. the buffers we're currently reading from + // 4. the buffers we're currently writing to. + // + // 1 gets doubled each iteration, 2 gets negated, 3 and 4 are swapped. + // int currentSize = 16; + SmallVector whileInitArgs{insertionSortSize, initParity}; + // First we read from `outputs0` (initialized by the insertion sort above). + llvm::copy(outputs0, std::back_inserter(whileInitArgs)); + llvm::copy(outputs1, std::back_inserter(whileInitArgs)); + + SmallVector whileArgTypes; + for (auto val : whileInitArgs) whileArgTypes.push_back(val.getType()); + + SmallVector whileArgLocs(whileArgTypes.size(), b.getLoc()); + + // while ( + auto whileOp = b.create(whileArgTypes, whileInitArgs); + OpBuilder::InsertionGuard guard(b); + + // currentSize < totalSize) + { + Block* before = + b.createBlock(&whileOp.getBefore(), {}, whileArgTypes, whileArgLocs); + Value currentSize = before->getArgument(0); + b.setInsertionPointToEnd(before); + b.create(arith.slt(currentSize, size), + before->getArguments()); + } + + size_t numArgs = inputTensors.size(); + // { + { + Block* after = + b.createBlock(&whileOp.getAfter(), {}, whileArgTypes, whileArgLocs); + + Value currentSize = after->getArgument(0); + Value parity = after->getArgument(1); + auto readBufs = after->getArguments().drop_front(2).take_front(numArgs); + auto writeBufs = after->getArguments().take_back(numArgs); + + Value twoCurrentSize = arith.add(currentSize, currentSize); + + // for (int start = 0; start < size; start += 2*currentSize) { + { + auto forOp = b.create(zero, size, twoCurrentSize); + b.setInsertionPointToStart(forOp.getBody()); + Value start = forOp.getBody()->getArgument(0); + + Value mid = b.create(size, arith.add(start, currentSize)); + Value end = b.create(size, arith.add(start, twoCurrentSize)); + emitMerge(b, start, mid, end, readBufs, writeBufs, comparator); + b.setInsertionPointAfter(forOp); + } + // } + + // parity = !parity; + Value one = b.create(1, 1); + Value notParity = arith.sub(one, parity); + // currentSize *= 2; + SmallVector nextWhileArgs{twoCurrentSize, notParity}; + llvm::copy(writeBufs, std::back_inserter(nextWhileArgs)); + llvm::copy(readBufs, std::back_inserter(nextWhileArgs)); + b.create(nextWhileArgs); + } + // } + + // The result is the parity bit. + return whileOp.getResults().drop_front(1).front(); +} + +// Helper struct for extracting 1d slices from tensors and memrefs. +struct Slicer { + Slicer(OpBuilder& b, uint64_t sortDim, Value sortDimSize, ValueRange ivs) + : sizes(ivs.size() + 1, b.getI64IntegerAttr(1)), + strides(ivs.size() + 1, b.getI64IntegerAttr(1)) { + sizes[sortDim] = sortDimSize; + for (size_t i = 0; i < ivs.size() + 1; ++i) { + if (i == sortDim) { + offsets.push_back(b.getI64IntegerAttr(0)); + } else { + offsets.push_back(ivs[i - static_cast(i > sortDim)]); + } + } + } + + RankedTensorType toSlicedType(RankedTensorType sourceType) { + return tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( + /*resultRank=*/1, sourceType, offsets, sizes, strides); + } + + MemRefType toSlicedType(MemRefType sourceType) { + return mlir::cast(memref::SubViewOp::inferRankReducedResultType( + {ShapedType::kDynamic} /*1D output*/, sourceType, offsets, sizes, + strides)); + } + + template + Value slice(ImplicitLocOpBuilder& b, Value input) { + Ty ty = mlir::cast(input.getType()); + return b.create(toSlicedType(ty), input, offsets, sizes, strides) + .getResult(); + } + + Value apply(ImplicitLocOpBuilder& b, Value input) { + Type inTy = input.getType(); + if (mlir::isa(inTy)) { + return slice(b, input); + } + assert(mlir::isa(inTy)); + return slice(b, input); + } + + SmallVector offsets; + SmallVector sizes; + SmallVector strides; +}; + +SmallVector sliceMemrefsOrTensors(ImplicitLocOpBuilder& b, + SmallVector& ivs, + Value sortDimSize, + ValueRange memrefsOrTensors, + SortOp op) { + if (ivs.empty()) return memrefsOrTensors; + + SmallVector outputs; + Slicer slicer(b, op.getDimension(), sortDimSize, ivs); + // Create subviews/slices. + for (Value out : memrefsOrTensors) { + outputs.push_back(slicer.apply(b, out)); + } + + return outputs; +} + +struct SortOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SortOp op, + PatternRewriter& rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + // Note: the output memrefs aren't necessarily the ones that we return, + SmallVector outputMemrefs; + SmallVector scratchMemrefs; + + Value firstOperand = op.getOperands().front(); + auto firstOperandType = mlir::cast(firstOperand.getType()); + int64_t inputRank = firstOperandType.getRank(); + + Value sortDimSize = b.createOrFold( + firstOperand, b.create(op.getDimension())); + int64_t staticSortDimSize = firstOperandType.getDimSize(op.getDimension()); + + SmallVector dynamicDims; + for (int i = 0; i < inputRank; ++i) { + if (!firstOperandType.isDynamicDim(i)) continue; + Value index = b.create(i); + Value dimOp = b.create(firstOperand, index); + dynamicDims.push_back(dimOp); + } + + // Allocate output and scratch memrefs. If the size of the sort dimension is + // statically known to be <= kInsertionSortSize, `scratchMemrefs` are unused + // and will be cleaned up later. + for (auto input : op.getOperands()) { + auto inputType = mlir::cast(input.getType()); + auto memRefType = + MemRefType::get(inputType.getShape(), inputType.getElementType()); + + outputMemrefs.push_back( + b.create(memRefType, dynamicDims)); + scratchMemrefs.push_back( + b.create(memRefType, dynamicDims)); + } + + b.setInsertionPoint(op); + Value zero = b.create(0); + Value one = b.create(1); + + Value forInitArg = b.create(0, 1); + SmallVector forOps; + SmallVector ivs; + forOps.reserve(inputRank - 1); + ivs.reserve(inputRank - 1); + for (int64_t i = 0; i < inputRank; ++i) { + if (i != static_cast(op.getDimension())) { + Value dim = b.create(i); + Value ub = b.create(firstOperand, dim); + scf::ForOp& forOp = forOps.emplace_back( + b.create(zero, ub, one, ValueRange{forInitArg})); + ivs.push_back(forOp.getInductionVar()); + b.setInsertionPointToStart(&forOp.getRegion().front()); + } + } + SmallVector inputs = + sliceMemrefsOrTensors(b, ivs, sortDimSize, op.getOperands(), op); + SmallVector outputs = + sliceMemrefsOrTensors(b, ivs, sortDimSize, outputMemrefs, op); + SmallVector scratches = + sliceMemrefsOrTensors(b, ivs, sortDimSize, scratchMemrefs, op); + + Value parity = + emitBottomUpMergeSort(b, zero, sortDimSize, staticSortDimSize, inputs, + outputs, scratches, op.getRegion()); + + // Pass the parity bit through the for loops. + for (auto i = static_cast(forOps.size() - 1); i >= 0; --i) { + b.setInsertionPointToEnd(&forOps[i].getRegion().front()); + b.create(ValueRange{parity}); + parity = forOps[i]->getResult(0); + } + b.setInsertionPoint(op); + + SmallVector outputTensors; + for (auto [out0, out1] : llvm::zip(outputMemrefs, scratchMemrefs)) { + outputTensors.push_back(b.create( + b.create(parity, out1, out0), /*restrict=*/true)); + } + + rewriter.replaceOp(op, outputTensors); + return success(); + } +}; + +struct StablehloLegalizeSortPass + : public catalyst::impl::StablehloLegalizeSortPassBase { + // Perform the lowering to MLIR control flow. + void runOnOperation() override { + func::FuncOp f = getOperation(); + MLIRContext* ctx = f.getContext(); + + RewritePatternSet patterns(ctx); + patterns.add(ctx); + + mlir::ConversionTarget target(*ctx); + target.markUnknownOpDynamicallyLegal([](Operation*) { return true; }); + target.addIllegalOp(); + + if (failed(applyPartialConversion(f, target, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr catalyst::createStablehloLegalizeSortPass() +{ + return std::make_unique(); +} diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index 383347bcef..175eed5480 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -45,7 +45,7 @@ void createHloLoweringPipeline(OpPassManager &pm) //pm.addNestedPass(stablehlo::createLegalizeControlFlowPass()); // (?) pm.addNestedPass(stablehlo::createStablehloLegalizeToLinalgPass()); //pm.addNestedPass(stablehlo::createLegalizeToStdPass()); - //pm.addNestedPass(stablehlo::createLegalizeSortPass()); + pm.addNestedPass(catalyst::createStablehloLegalizeSortPass()); // (!) pm.addPass(stablehlo::createConvertToSignlessPass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(catalyst::createScatterLoweringPass()); From 4e45fa611c71e4102f8db57b2ba6a235bb52b459 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Mon, 21 Jul 2025 14:11:44 -0400 Subject: [PATCH 10/63] checkout stablehlo submodule in CI --- .github/workflows/check-catalyst.yaml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index fef9e4ab45..a3bb570840 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -314,6 +314,14 @@ jobs: key: ${{ runner.os }}-ccache-${{ github.run_id }} restore-keys: ${{ runner.os }}-ccache- + # just hard code commit manually, set up this stablehlo dep verisons infra later + - name: Clone Stablehlo Submodule + uses: actions/checkout@v4 + with: + repository: openxla/stablehlo + ref: f1f035fea33dcfdd7c471eb7f39174b344003117 + path: mlir/stablehlo + - name: Build MLIR Dialects run: | CCACHE_DIR="$(pwd)/.ccache" \ From b5ddbd6029b0c67c2b3797a261b11c4e54e307dd Mon Sep 17 00:00:00 2001 From: paul0403 Date: Mon, 21 Jul 2025 14:38:02 -0400 Subject: [PATCH 11/63] turn off warnings as errors in stablehlo... is this how you do it? --- mlir/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index 1fd05b4bbb..547956febd 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -195,4 +195,5 @@ add_subdirectory(cmake/modules) add_subdirectory(test) unset(LLVM_USE_LINKER) set(STABLEHLO_BUILD_EMBEDDED ON) +set(CMAKE_COMPILE_WARNING_AS_ERROR OFF) add_subdirectory(stablehlo) From ae895d0baee8022dc69db237594b10f390b28923 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Mon, 21 Jul 2025 16:51:10 -0400 Subject: [PATCH 12/63] add back stablehlo legalize to std pass --- frontend/catalyst/pipelines.py | 2 +- .../Catalyst/Transforms/CMakeLists.txt | 7 + mlir/include/Catalyst/Transforms/Passes.h | 1 + mlir/include/Catalyst/Transforms/Passes.td | 15 + ...stablehlo_legalize_to_standard_patterns.td | 119 ++++++++ mlir/lib/Catalyst/Transforms/CMakeLists.txt | 2 + .../Catalyst/Transforms/RegisterAllPasses.cpp | 1 + .../Transforms/stablehlo_legalize_to_std.cpp | 264 ++++++++++++++++++ mlir/lib/Driver/Pipelines.cpp | 2 +- 9 files changed, 411 insertions(+), 2 deletions(-) create mode 100644 mlir/include/Catalyst/Transforms/stablehlo_legalize_to_standard_patterns.td create mode 100644 mlir/lib/Catalyst/Transforms/stablehlo_legalize_to_std.cpp diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index 2eea6f2edd..2e7b53ee90 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -227,7 +227,7 @@ def get_hlo_lowering_stage(_options: CompileOptions) -> List[str]: #"stablehlo-legalize-to-hlo", #"func.func(mhlo-legalize-control-flow)", "func.func(stablehlo-legalize-to-linalg)", - #"func.func(mhlo-legalize-to-std)", + "func.func(stablehlo-legalize-to-std)", "func.func(stablehlo-legalize-sort)", "stablehlo-convert-to-signless", "canonicalize", diff --git a/mlir/include/Catalyst/Transforms/CMakeLists.txt b/mlir/include/Catalyst/Transforms/CMakeLists.txt index 52802b77fc..1b4879ff71 100644 --- a/mlir/include/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/include/Catalyst/Transforms/CMakeLists.txt @@ -2,3 +2,10 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name catalyst) add_public_tablegen_target(MLIRCatalystPassIncGen) add_mlir_doc(Passes CatalystPasses ./ -gen-pass-doc) + +# The following is taken from mhlo to build the --legalize-to-std pass +set(LLVM_TARGET_DEFINITIONS stablehlo_legalize_to_standard_patterns.td) +include_directories( + ${CATALYST_MAIN_INCLUDE_DIR}/../stablehlo) +mlir_tablegen(generated_stablehlo_legalize_to_standard.cpp.inc -gen-rewriters) +add_public_tablegen_target(MLIRStablehloLegalizeToStandardIncGen) diff --git a/mlir/include/Catalyst/Transforms/Passes.h b/mlir/include/Catalyst/Transforms/Passes.h index 271a05c758..63778d65ce 100644 --- a/mlir/include/Catalyst/Transforms/Passes.h +++ b/mlir/include/Catalyst/Transforms/Passes.h @@ -38,6 +38,7 @@ std::unique_ptr createRegisterInactiveCallbackPass(); std::unique_ptr createScatterLoweringPass(); std::unique_ptr createSplitMultipleTapesPass(); std::unique_ptr createStablehloLegalizeSortPass(); +std::unique_ptr createStablehloLegalizeToStdPass(); void registerAllCatalystPasses(); diff --git a/mlir/include/Catalyst/Transforms/Passes.td b/mlir/include/Catalyst/Transforms/Passes.td index 14139ae503..1971daf3ed 100644 --- a/mlir/include/Catalyst/Transforms/Passes.td +++ b/mlir/include/Catalyst/Transforms/Passes.td @@ -300,4 +300,19 @@ def StablehloLegalizeSortPass : Pass<"stablehlo-legalize-sort", "func::FuncOp"> "scf::SCFDialect", "tensor::TensorDialect"]; } +// mhlo legalize to std pass. +// mhlo dropped the -legalize-to-std pass when migrating to stablehlo. +// We manually add it back. +// +// This pass has been modified from its original form in the tensorflow/mlir-hlo repository at +// https://github.com/tensorflow/mlir-hlo/blob/a5529d99fc4d1132b0c282a053d26c11e6636b3a/mhlo/transforms/mhlo_passes.td +// released under the Apache License, Version 2.0, with the following copyright notice: +// +// * Licensed under the Apache License, Version 2.0 (the "License"); + +def StablehloLegalizeToStandardPass : Pass<"stablehlo-legalize-to-std", "func::FuncOp"> { + let summary = "Legalize from MHLO dialect to standard dialect."; + let constructor = "createStablehloLegalizeToStdPass()"; +} + #endif // CATALYST_PASSES diff --git a/mlir/include/Catalyst/Transforms/stablehlo_legalize_to_standard_patterns.td b/mlir/include/Catalyst/Transforms/stablehlo_legalize_to_standard_patterns.td new file mode 100644 index 0000000000..a26ecde9cb --- /dev/null +++ b/mlir/include/Catalyst/Transforms/stablehlo_legalize_to_standard_patterns.td @@ -0,0 +1,119 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +// This file is taken from the +// tensorflow/mlir-hlo +// repository, under the Apache 2.0 License, at +// https://github.com/tensorflow/mlir-hlo/blob/a5529d99fc4d1132b0c282a053d26c11e6636b3a/mhlo/transforms/legalize_to_standard/legalize_to_standard_patterns.td +// with the following copyright notice: + + /* Copyright 2019 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + + + +// This is the legalization pattern definition file for MHLO to StandardOps. + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "mlir/Dialect/Math/IR/MathOps.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "stablehlo/dialect/StablehloOps.td" + +//===----------------------------------------------------------------------===// +// Nullary op patterns. +//===----------------------------------------------------------------------===// + +def : Pat<(StableHLO_ConstantOp ElementsAttr:$value), + (Arith_ConstantOp $value)>; + +//===----------------------------------------------------------------------===// +// Binary op patterns. +//===----------------------------------------------------------------------===// + +def IsSameSizePred : CPred< + "cast($0.getType()).getShape() " + "== cast($1.getType()).getShape()">; +def IsSameSizeConstraint : Constraint; +def createFastMathNone : NativeCodeCall< + "::mlir::arith::FastMathFlagsAttr::get(" + "$_builder.getContext(), ::mlir::arith::FastMathFlags::none" + ")">; +def createOverflowNone : NativeCodeCall< + "::mlir::arith::IntegerOverflowFlagsAttr::get(" + "$_builder.getContext(), ::mlir::arith::IntegerOverflowFlags::none" + ")">; +def createDenormalIEEE : NativeCodeCall< + "::mlir::arith::DenormalModeAttr::get(" + "$_builder.getContext(), ::mlir::arith::DenormalMode::ieee" + ")">; + + +// Unary Lowering Patterns. +def : Pat<(StableHLO_CeilOp HLO_FpTensor:$i), (Math_CeilOp $i, (createFastMathNone ))>; + +// Binary Lowering Patterns. +def : Pat<(StableHLO_AndOp HLO_IntTensor:$l, HLO_IntTensor:$r), + (Arith_AndIOp $l, $r), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(StableHLO_OrOp HLO_IntTensor:$l, HLO_IntTensor:$r), + (Arith_OrIOp $l, $r), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(StableHLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r), + (Arith_AddFOp $l, $r, (createFastMathNone )), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(StableHLO_SubtractOp HLO_FpTensor:$l, HLO_FpTensor:$r), + (Arith_SubFOp $l, $r, (createFastMathNone )), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(StableHLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r), + (Arith_MulFOp $l, $r, (createFastMathNone )), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(StableHLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r), + (Arith_DivFOp $l, $r, (createFastMathNone )), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(StableHLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r), + (Arith_RemFOp $l, $r, (createFastMathNone )), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(StableHLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r), + (Arith_AddIOp $l, $r, (createOverflowNone )), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(StableHLO_SubtractOp HLO_IntTensor:$l, HLO_IntTensor:$r), + (Arith_SubIOp $l, $r, (createOverflowNone )), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(StableHLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r), + (Arith_MulIOp $l, $r, (createOverflowNone )), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(StableHLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r), + (Arith_DivSIOp $l, $r), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(StableHLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r), + (Arith_RemSIOp $l, $r), + [(IsSameSizeConstraint $l, $r)]>; + +def : Pat<(StableHLO_SelectOp $pred, $tv, $fv), + (SelectOp $pred, $tv, $fv), + [(IsSameSizeConstraint $pred, $tv), (IsSameSizeConstraint $tv, $fv)]>; diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index a6313f4682..2a29dac8e4 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -26,6 +26,7 @@ file(GLOB SRC ScatterPatterns.cpp SplitMultipleTapes.cpp stablehlo_legalize_sort.cpp + stablehlo_legalize_to_std.cpp TBAAPatterns.cpp TBAATagsPass.cpp ) @@ -41,6 +42,7 @@ set(LIBS set(DEPENDS MLIRCatalystPassIncGen + MLIRStablehloLegalizeToStandardIncGen StablehloBaseIncGen StablehloOpsIncGen ) diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index 74031733f7..4e270e6736 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -68,4 +68,5 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createLoopBoundaryOptimizationPass); mlir::registerPass(catalyst::createMBQCConversionPass); mlir::registerPass(catalyst::createStablehloLegalizeSortPass); + mlir::registerPass(catalyst::createStablehloLegalizeToStdPass); } diff --git a/mlir/lib/Catalyst/Transforms/stablehlo_legalize_to_std.cpp b/mlir/lib/Catalyst/Transforms/stablehlo_legalize_to_std.cpp new file mode 100644 index 0000000000..371696c740 --- /dev/null +++ b/mlir/lib/Catalyst/Transforms/stablehlo_legalize_to_std.cpp @@ -0,0 +1,264 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is taken from the +// tensorflow/mlir-hlo +// repository, under the Apache 2.0 License, at +// https://github.com/tensorflow/mlir-hlo/blob/a5529d99fc4d1132b0c282a053d26c11e6636b3a/mhlo/transforms/legalize_to_standard/legalize_to_standard.cc +// with the following copyright notice: + +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The modifications are porting the pass from the upstream MHLO namespace to +// catalyst namespace. + +// This file implements logic for lowering MHLO dialect to Standard dialect. + +#include +#include +#include + +// #include "mhlo/transforms/rewriters.h" // (??) +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" + +#include "Catalyst/IR/CatalystDialect.h" +#include "Catalyst/Transforms/Passes.h" + +using namespace mlir; +using namespace stablehlo; +using namespace catalyst; + +namespace catalyst { + +#define GEN_PASS_DEF_STABLEHLOLEGALIZETOSTANDARDPASS +#define GEN_PASS_DECL_STABLEHLOLEGALIZETOSTANDARDPASS +#include "Catalyst/Transforms/Passes.h.inc" +#include "Catalyst/Transforms/generated_stablehlo_legalize_to_standard.cpp.inc" + +} // namespace catalyst + +namespace { + +class CompareIConvert : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(stablehlo::CompareOp op, PatternRewriter &rewriter) const override + { + auto lhs = op.getLhs(); + auto rhs = op.getRhs(); + auto lhsType = mlir::cast(lhs.getType()); + auto rhsType = mlir::cast(rhs.getType()); + + // Broadcasting not supported by this rewrite. + if (lhsType.getShape() != rhsType.getShape()) + return failure(); + + if (!lhsType.getElementType().isSignlessInteger() || + !rhsType.getElementType().isSignlessInteger()) + return failure(); + + std::optional comparePredicate = std::nullopt; + switch (op.getComparisonDirection()) { + case ComparisonDirection::EQ: + comparePredicate = arith::CmpIPredicate::eq; + break; + case ComparisonDirection::NE: + comparePredicate = arith::CmpIPredicate::ne; + break; + case ComparisonDirection::LT: + comparePredicate = arith::CmpIPredicate::slt; + break; + case ComparisonDirection::LE: + comparePredicate = arith::CmpIPredicate::sle; + break; + case ComparisonDirection::GT: + comparePredicate = arith::CmpIPredicate::sgt; + break; + case ComparisonDirection::GE: + comparePredicate = arith::CmpIPredicate::sge; + break; + } + + if (!comparePredicate.has_value()) + return failure(); + + rewriter.replaceOpWithNewOp(op, comparePredicate.value(), lhs, rhs); + return success(); + } +}; + +class CompareFConvert : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(stablehlo::CompareOp op, PatternRewriter &rewriter) const override + { + auto lhs = op.getLhs(); + auto rhs = op.getRhs(); + auto lhsType = mlir::cast(lhs.getType()); + auto rhsType = mlir::cast(rhs.getType()); + + // Broadcasting not supported by this rewrite. + if (lhsType.getShape() != rhsType.getShape()) + return failure(); + + if (!mlir::isa(lhsType.getElementType()) || + !mlir::isa(rhsType.getElementType())) + return failure(); + + std::optional comparePredicate = std::nullopt; + switch (op.getComparisonDirection()) { + case ComparisonDirection::EQ: + comparePredicate = arith::CmpFPredicate::OEQ; + break; + case ComparisonDirection::NE: + comparePredicate = arith::CmpFPredicate::UNE; + break; + case ComparisonDirection::LT: + comparePredicate = arith::CmpFPredicate::OLT; + break; + case ComparisonDirection::LE: + comparePredicate = arith::CmpFPredicate::OLE; + break; + case ComparisonDirection::GT: + comparePredicate = arith::CmpFPredicate::OGT; + break; + case ComparisonDirection::GE: + comparePredicate = arith::CmpFPredicate::OGE; + break; + } + + if (!comparePredicate.has_value()) + return failure(); + + rewriter.replaceOpWithNewOp(op, comparePredicate.value(), lhs, rhs); + return success(); + } +}; + +// Replace IotaOp with an integer constant. A ConvertOp is added to +// convert the integer constant to iota result type. For complex types, the real +// part is replaced with the generated constant and the imaginary part is +// replaced with zero tensor. +class ConvertIotaOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(stablehlo::IotaOp op, PatternRewriter &rewriter) const override + { + auto outputType = mlir::cast(op.getType()); + auto outputSize = outputType.getNumElements(); + auto dimension = op.getIotaDimension(); + auto maxDimSize = outputType.getDimSize(dimension); + + auto elementType = outputType.getElementType(); + int bitwidth; + + auto complexTy = mlir::dyn_cast(elementType); + Type intOrFloatTy = elementType; + if (complexTy) + intOrFloatTy = complexTy.getElementType(); + + bitwidth = intOrFloatTy.getIntOrFloatBitWidth(); + llvm::SmallVector values; + values.reserve(outputSize); + + int64_t increaseStride = outputSize; + for (uint64_t i = 0; i <= dimension; i++) { + increaseStride /= outputType.getDimSize(i); + } + + int64_t currentValue = 0; + for (int i = 0; i < outputSize; i++) { + int64_t value = (currentValue / increaseStride) % maxDimSize; + values.push_back(APInt(bitwidth, value)); + ++currentValue; + } + + auto intShapeType = RankedTensorType::get( + outputType.getShape(), IntegerType::get(rewriter.getContext(), bitwidth)); + auto loc = op.getLoc(); + auto integerConst = rewriter.create( + loc, DenseIntElementsAttr::get(intShapeType, values)); + + auto intOrFloatShapeTy = RankedTensorType::get(outputType.getShape(), intOrFloatTy); + + auto iotaConst = rewriter.create(loc, intOrFloatShapeTy, integerConst); + + // For int/float types we are done, replace op and return. + if (!complexTy) { + rewriter.replaceOp(op, iotaConst.getResult()); + return success(); + } + + // For complex types, generate a constant tensor of zeroes for the imaginary + // part and use iota_const for real part. + auto zeroes = rewriter.create( + loc, DenseIntElementsAttr::get(intShapeType, APInt(bitwidth, 0))); + auto imagZeroes = rewriter.create(loc, intOrFloatShapeTy, zeroes); + rewriter.replaceOpWithNewOp(op, iotaConst, imagZeroes); + return success(); + } +}; + +void populateStablehloToStdPatterns(RewritePatternSet *patterns, mlir::MLIRContext *ctx) +{ + populateWithGenerated(*patterns); + patterns->add(ctx); +} + +struct StablehloLegalizeToStandardPass + : public catalyst::impl::StablehloLegalizeToStandardPassBase { + void getDependentDialects(DialectRegistry ®istry) const override + { + registry.insert(); + } + + /// Perform the lowering to Standard dialect. + void runOnOperation() override + { + RewritePatternSet patterns(&getContext()); + populateStablehloToStdPatterns(&patterns, &getContext()); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; +} // end anonymous namespace + +std::unique_ptr catalyst::createStablehloLegalizeToStdPass() +{ + return std::make_unique(); +} diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index 175eed5480..954e18e70b 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -44,7 +44,7 @@ void createHloLoweringPipeline(OpPassManager &pm) //pm.addPass(stablehlo::createStablehloLegalizeToHloPass()); //pm.addNestedPass(stablehlo::createLegalizeControlFlowPass()); // (?) pm.addNestedPass(stablehlo::createStablehloLegalizeToLinalgPass()); - //pm.addNestedPass(stablehlo::createLegalizeToStdPass()); + pm.addNestedPass(catalyst::createStablehloLegalizeToStdPass()); pm.addNestedPass(catalyst::createStablehloLegalizeSortPass()); // (!) pm.addPass(stablehlo::createConvertToSignlessPass()); pm.addPass(mlir::createCanonicalizerPass()); From c938853c9aed747103dc2982ae1e823f6f0bb283 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Mon, 21 Jul 2025 17:08:11 -0400 Subject: [PATCH 13/63] add back legalize-to-control-flow pass --- frontend/catalyst/pipelines.py | 2 +- mlir/include/Catalyst/Transforms/Passes.h | 1 + mlir/include/Catalyst/Transforms/Passes.td | 16 + mlir/lib/Catalyst/Transforms/CMakeLists.txt | 1 + .../Catalyst/Transforms/RegisterAllPasses.cpp | 1 + .../stablehlo_legalize_control_flow.cpp | 316 ++++++++++++++++++ mlir/lib/Driver/Pipelines.cpp | 2 +- 7 files changed, 337 insertions(+), 2 deletions(-) create mode 100644 mlir/lib/Catalyst/Transforms/stablehlo_legalize_control_flow.cpp diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index 2e7b53ee90..1adacbc350 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -225,7 +225,7 @@ def get_hlo_lowering_stage(_options: CompileOptions) -> List[str]: "canonicalize", "func.func(chlo-legalize-to-stablehlo)", #"stablehlo-legalize-to-hlo", - #"func.func(mhlo-legalize-control-flow)", + "func.func(stablehlo-legalize-control-flow)", "func.func(stablehlo-legalize-to-linalg)", "func.func(stablehlo-legalize-to-std)", "func.func(stablehlo-legalize-sort)", diff --git a/mlir/include/Catalyst/Transforms/Passes.h b/mlir/include/Catalyst/Transforms/Passes.h index 63778d65ce..82957a6201 100644 --- a/mlir/include/Catalyst/Transforms/Passes.h +++ b/mlir/include/Catalyst/Transforms/Passes.h @@ -39,6 +39,7 @@ std::unique_ptr createScatterLoweringPass(); std::unique_ptr createSplitMultipleTapesPass(); std::unique_ptr createStablehloLegalizeSortPass(); std::unique_ptr createStablehloLegalizeToStdPass(); +std::unique_ptr createStablehloLegalizeControlFlowPass(); void registerAllCatalystPasses(); diff --git a/mlir/include/Catalyst/Transforms/Passes.td b/mlir/include/Catalyst/Transforms/Passes.td index 1971daf3ed..46080b1f74 100644 --- a/mlir/include/Catalyst/Transforms/Passes.td +++ b/mlir/include/Catalyst/Transforms/Passes.td @@ -315,4 +315,20 @@ def StablehloLegalizeToStandardPass : Pass<"stablehlo-legalize-to-std", "func::F let constructor = "createStablehloLegalizeToStdPass()"; } +// mhlo legalize to control flow pass. +// mhlo dropped the -legalize-to-control-flow pass when migrating to stablehlo. +// We manually add it back. +// +// This pass has been modified from its original form in the tensorflow/mlir-hlo repository at +// https://github.com/tensorflow/mlir-hlo/blob/a5529d99fc4d1132b0c282a053d26c11e6636b3a/mhlo/transforms/mhlo_passes.td +// released under the Apache License, Version 2.0, with the following copyright notice: +// +// * Licensed under the Apache License, Version 2.0 (the "License"); + +def StablehloLegalizeControlFlowPass : Pass<"stablehlo-legalize-control-flow", "func::FuncOp"> { + let summary = "Legalize from MHLO control flow to SCF control flow."; + let constructor = "createStablehloLegalizeControlFlowPass()"; + let dependentDialects = ["scf::SCFDialect", "tensor::TensorDialect"]; +} + #endif // CATALYST_PASSES diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index 2a29dac8e4..b67727a900 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -25,6 +25,7 @@ file(GLOB SRC scatter_lowering.cpp ScatterPatterns.cpp SplitMultipleTapes.cpp + stablehlo_legalize_control_flow.cpp stablehlo_legalize_sort.cpp stablehlo_legalize_to_std.cpp TBAAPatterns.cpp diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index 4e270e6736..b7d56dc84a 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -69,4 +69,5 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createMBQCConversionPass); mlir::registerPass(catalyst::createStablehloLegalizeSortPass); mlir::registerPass(catalyst::createStablehloLegalizeToStdPass); + mlir::registerPass(catalyst::createStablehloLegalizeControlFlowPass); } diff --git a/mlir/lib/Catalyst/Transforms/stablehlo_legalize_control_flow.cpp b/mlir/lib/Catalyst/Transforms/stablehlo_legalize_control_flow.cpp new file mode 100644 index 0000000000..3faf312560 --- /dev/null +++ b/mlir/lib/Catalyst/Transforms/stablehlo_legalize_control_flow.cpp @@ -0,0 +1,316 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is taken from the +// tensorflow/mlir-hlo +// repository, under the Apache 2.0 License, at +// https://github.com/tensorflow/mlir-hlo/blob/a5529d99fc4d1132b0c282a053d26c11e6636b3a/mhlo/transforms/legalize_control_flow/legalize_control_flow.cc +// with the following copyright notice: + +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The modifications are porting the pass from the upstream MHLO namespace to +// catalyst namespace. + +// This file implements logic for lowering MHLO dialect to SCF dialect. +#include +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" +#include "llvm/Support/Casting.h" + +#include "Catalyst/IR/CatalystDialect.h" +#include "Catalyst/Transforms/Passes.h" + +using namespace mlir; +using namespace stablehlo; +using namespace catalyst; + +namespace catalyst { + +#define GEN_PASS_DEF_STABLEHLOLEGALIZECONTROLFLOWPASS +#define GEN_PASS_DECL_STABLEHLOLEGALIZECONTROLFLOWPASS +#include "Catalyst/Transforms/Passes.h.inc" + +} // namespace catalyst + +namespace { + +// All transformations in this file take mhlo blocks which end with +// stablehlo::ReturnOp and lower to SCF ops which end with scf::YieldOp. Inline an +// entire block with the only change being return -> yield. +void inlineMhloRegionIntoSCFRegion(PatternRewriter &rewriter, Region &mhlo, Region &scf) +{ + // Remove an existing block, then move the region over. + if (!scf.empty()) + rewriter.eraseBlock(&scf.back()); + rewriter.inlineRegionBefore(mhlo, scf, scf.end()); + // Fix up the terminator. + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToEnd(&scf.back()); + auto *terminator = scf.back().getTerminator(); + rewriter.replaceOpWithNewOp(terminator, terminator->getOperands()); +} + +// mhlo ops need inputs to be tensors, but scalar values can be a scalar tensor +// or a 1 element tensor. To handle this, collapse shape before extracting the +// scalar value when necessary. +Value extractTensorValue(OpBuilder &b, Value tensor) +{ + auto loc = tensor.getLoc(); + if (mlir::cast(tensor.getType()).hasRank() && + mlir::cast(tensor.getType()).getRank() != 0) { + tensor = + b.create(loc, tensor, SmallVector()); + } + return b.create(loc, tensor, ValueRange()); +} + +struct ScfForBounds { + Value lb; + Value ub; + Value step; + unsigned indexArgIndex; +}; + +std::optional extractForBounds(stablehlo::WhileOp op) +{ + auto &cond = op.getCond().front(); + auto &body = op.getBody().front(); + if (cond.getOperations().size() != 2) + return std::nullopt; + + auto matchBbArg = [](Value v, Block &block) -> std::optional { + if (!mlir::isa(v) || v.getParentBlock() != &block) + return std::nullopt; + return mlir::cast(v).getArgNumber(); + }; + + auto compare = llvm::dyn_cast(cond.front()); + // If the rhs of the comapare is defined outside the block, it's a constant + // within the loop. + if (!compare || compare.getComparisonDirection() != stablehlo::ComparisonDirection::LT || + compare.getRhs().getParentBlock() == &cond || + !getElementTypeOrSelf(compare.getLhs().getType()).isSignlessIntOrIndex()) { + return std::nullopt; + } + + auto iterArg = matchBbArg(compare.getLhs(), cond); + if (!iterArg) + return std::nullopt; + + auto add = llvm::dyn_cast_or_null( + body.getTerminator()->getOperand(*iterArg).getDefiningOp()); + if (!add || matchBbArg(add.getLhs(), body) != iterArg || + add.getRhs().getParentBlock() == &body) { + return std::nullopt; + } + + ScfForBounds bounds; + bounds.ub = compare.getRhs(); + bounds.step = add.getRhs(); + bounds.lb = op->getOperand(*iterArg); + bounds.indexArgIndex = *iterArg; + return bounds; +} + +// Rewrites `stablehlo.while` to `scf.while` or `scf.for`. +struct WhileOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(stablehlo::WhileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + auto loc = op.getLoc(); + + if (auto bounds = extractForBounds(op)) { + auto newForOp = rewriter.create( + loc, extractTensorValue(rewriter, bounds->lb), + extractTensorValue(rewriter, bounds->ub), + extractTensorValue(rewriter, bounds->step), adaptor.getOperands()); + + rewriter.setInsertionPointToEnd(newForOp.getBody()); + // Inline while body, and only replace the mhlo.return with an scf.yield. + inlineMhloRegionIntoSCFRegion(rewriter, op.getBody(), newForOp.getRegion()); + auto indexArg = newForOp.getRegion().insertArgument( + unsigned{0}, newForOp.getLowerBound().getType(), loc); + auto oldIndexArg = newForOp.getRegion().getArgument(1 + bounds->indexArgIndex); + rewriter.setInsertionPointToStart(&newForOp.getRegion().front()); + auto indexArgTensor = + rewriter.create(loc, oldIndexArg.getType(), indexArg); + oldIndexArg.replaceAllUsesWith(indexArgTensor); + + rewriter.replaceOp(op, newForOp.getResults()); + return success(); + } + + auto newWhileOp = + rewriter.create(loc, op.getResultTypes(), adaptor.getOperands()); + + // Inline while condition. The block is the same, except the boolean result + // needs to be extracted and used with an scf.condition. + rewriter.inlineRegionBefore(op.getCond(), newWhileOp.getBefore(), + newWhileOp.getBefore().end()); + auto conditionReturn = + cast(newWhileOp.getBefore().front().getTerminator()); + rewriter.setInsertionPointToEnd(&newWhileOp.getBefore().front()); + Value i1 = extractTensorValue(rewriter, conditionReturn->getOperand(0)); + rewriter.replaceOpWithNewOp(conditionReturn, i1, + newWhileOp.getBeforeArguments()); + + // Inline while body, and only replace the mhlo.return with an scf.yield. + inlineMhloRegionIntoSCFRegion(rewriter, op.getBody(), newWhileOp.getAfter()); + + rewriter.replaceOp(op, newWhileOp.getResults()); + return success(); + } +}; + +// Rewrites `mhlo.if` to `scf.if`. +struct IfOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(stablehlo::IfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + auto scfIf = rewriter.create(op.getLoc(), op.getResultTypes(), + extractTensorValue(rewriter, adaptor.getPred()), + /*withElseRegion=*/true); + inlineMhloRegionIntoSCFRegion(rewriter, op.getTrueBranch(), scfIf.getThenRegion()); + inlineMhloRegionIntoSCFRegion(rewriter, op.getFalseBranch(), scfIf.getElseRegion()); + rewriter.replaceOp(op, scfIf.getResults()); + return success(); + } +}; + +// Rewrites `mhlo.case` to a nested `scf.if`. +struct CaseOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + // Recursively create if/else ops to handle each possible value in a case op. + scf::IfOp createNestedCases(int currentIdx, CaseOp op, OpAdaptor adaptor, + PatternRewriter &outerBuilder) const + { + Location loc = op.getLoc(); + Value idxValue = adaptor.getIndex(); + auto finalIdx = op.getBranches().size() - 2; + + // Determine if the current index matches the case index. + auto scalarType = idxValue.getType(); + auto shapedType = mlir::cast(scalarType); + auto constAttr = DenseElementsAttr::get( + shapedType, {mlir::cast(outerBuilder.getI32IntegerAttr(currentIdx))}); + Value currentIdxVal = + outerBuilder.create(loc, idxValue.getType(), constAttr); + + auto scfIf = outerBuilder.create( + loc, op.getResultTypes(), + extractTensorValue(outerBuilder, + outerBuilder.create( + loc, idxValue, currentIdxVal, ComparisonDirection::EQ)), + /*withElseRegion=*/true); + inlineMhloRegionIntoSCFRegion(outerBuilder, op.getBranches()[currentIdx], + scfIf.getThenRegion()); + int nextIdx = currentIdx + 1; + // Don't recurse for the final default block. + if (currentIdx == static_cast(finalIdx)) { + inlineMhloRegionIntoSCFRegion(outerBuilder, op.getBranches()[nextIdx], + scfIf.getElseRegion()); + } + else { + PatternRewriter::InsertionGuard guard(outerBuilder); + outerBuilder.setInsertionPointToEnd(&scfIf.getElseRegion().back()); + auto innerIf = createNestedCases(nextIdx, op, adaptor, outerBuilder); + outerBuilder.create(op.getLoc(), innerIf.getResults()); + } + return scfIf; + } + + LogicalResult matchAndRewrite(stablehlo::CaseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + // Inline the op if there is only a default block. + if (op.getBranches().size() == 1) { + Block &block = op.getBranches().front().front(); + auto results = block.getTerminator()->getOperands(); + // Remove the mhlo.return terminator, then inline the block. + rewriter.eraseOp(block.getTerminator()); + rewriter.inlineBlockBefore(/*source=*/&block, /*dest=*/op.getOperation(), + /*argValues=*/{}); + rewriter.replaceOp(op, results); + return success(); + } + + // Begin recursion with case 0. + rewriter.replaceOp(op, createNestedCases(0, op, adaptor, rewriter).getResults()); + return success(); + } +}; + +struct StablehloLegalizeControlFlowPass + : public catalyst::impl::StablehloLegalizeControlFlowPassBase< + StablehloLegalizeControlFlowPass> { + // Perform the lowering to MLIR control flow. + void runOnOperation() override + { + func::FuncOp f = getOperation(); + MLIRContext *ctx = f.getContext(); + + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + + mlir::ConversionTarget target(*ctx); + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); + target.addIllegalOp(); + + if (failed(applyPartialConversion(f, target, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr catalyst::createStablehloLegalizeControlFlowPass() +{ + return std::make_unique(); +} diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index 954e18e70b..55b5576fd7 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -42,7 +42,7 @@ void createHloLoweringPipeline(OpPassManager &pm) pm.addPass(mlir::createCanonicalizerPass()); pm.addNestedPass(stablehlo::createChloLegalizeToStablehloPass()); //pm.addPass(stablehlo::createStablehloLegalizeToHloPass()); - //pm.addNestedPass(stablehlo::createLegalizeControlFlowPass()); + pm.addNestedPass(catalyst::createStablehloLegalizeControlFlowPass()); // (?) pm.addNestedPass(stablehlo::createStablehloLegalizeToLinalgPass()); pm.addNestedPass(catalyst::createStablehloLegalizeToStdPass()); pm.addNestedPass(catalyst::createStablehloLegalizeSortPass()); From c18f67888a96dd65795ae0c154e1ac8edc7f6710 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Mon, 21 Jul 2025 17:34:24 -0400 Subject: [PATCH 14/63] register stablehlo optimization passes This is needed for `stablehlo-aggressive-simplification`, which contains the lowering for stablehlo::ReduceOp --- frontend/catalyst/pipelines.py | 4 +++- mlir/lib/Driver/CMakeLists.txt | 2 ++ mlir/lib/Driver/CompilerDriver.cpp | 4 +++- mlir/tools/quantum-opt/CMakeLists.txt | 1 + mlir/tools/quantum-opt/quantum-opt.cpp | 5 +++-- 5 files changed, 12 insertions(+), 4 deletions(-) diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index 1adacbc350..bb792e4fe1 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -226,7 +226,9 @@ def get_hlo_lowering_stage(_options: CompileOptions) -> List[str]: "func.func(chlo-legalize-to-stablehlo)", #"stablehlo-legalize-to-hlo", "func.func(stablehlo-legalize-control-flow)", - "func.func(stablehlo-legalize-to-linalg)", + #"func.func(stablehlo-legalize-to-linalg)", + "func.func(stablehlo-aggressive-simplification)", + "stablehlo-legalize-to-linalg{enable-primitive-ops}", "func.func(stablehlo-legalize-to-std)", "func.func(stablehlo-legalize-sort)", "stablehlo-convert-to-signless", diff --git a/mlir/lib/Driver/CMakeLists.txt b/mlir/lib/Driver/CMakeLists.txt index b979922561..fa448ed77c 100644 --- a/mlir/lib/Driver/CMakeLists.txt +++ b/mlir/lib/Driver/CMakeLists.txt @@ -41,6 +41,7 @@ set(LIBS MLIRIon ion-transforms StablehloRegister + StablehloOptimizationPasses MLIRCatalystTest ${ENZYME_LIB} StablehloCAPI @@ -53,6 +54,7 @@ add_mlir_library(CatalystCompilerDriver DEPENDS StablehloBaseIncGen + OptimizationPassesIncGen LINK_LIBS PRIVATE ${LIBS} diff --git a/mlir/lib/Driver/CompilerDriver.cpp b/mlir/lib/Driver/CompilerDriver.cpp index beedf0695e..701dd68a93 100644 --- a/mlir/lib/Driver/CompilerDriver.cpp +++ b/mlir/lib/Driver/CompilerDriver.cpp @@ -24,9 +24,10 @@ #include #include -#include "stablehlo/transforms/Passes.h" #include "stablehlo/dialect/Register.h" #include "stablehlo/integrations/c/StablehloPasses.h" +#include "stablehlo/transforms/Passes.h" +#include "stablehlo/transforms/optimization/Passes.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/InitAllDialects.h" @@ -963,6 +964,7 @@ int QuantumDriverMainFromCL(int argc, char **argv) registerAllCatalystPasses(); registerAllCatalystPipelines(); mlirRegisterAllStablehloPasses(); + mlir::stablehlo::registerOptimizationPasses(); registerAllCatalystDialects(registry); registerLLVMTranslations(registry); diff --git a/mlir/tools/quantum-opt/CMakeLists.txt b/mlir/tools/quantum-opt/CMakeLists.txt index 1d3eea1564..1ec9a2283d 100644 --- a/mlir/tools/quantum-opt/CMakeLists.txt +++ b/mlir/tools/quantum-opt/CMakeLists.txt @@ -22,6 +22,7 @@ set(LIBS ion-transforms StablehloRegister StablehloPasses + StablehloOptimizationPasses StablehloOps StablehloCAPI MLIRCatalystTest diff --git a/mlir/tools/quantum-opt/quantum-opt.cpp b/mlir/tools/quantum-opt/quantum-opt.cpp index 8b526e363d..a7b532b228 100644 --- a/mlir/tools/quantum-opt/quantum-opt.cpp +++ b/mlir/tools/quantum-opt/quantum-opt.cpp @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - #include "mlir/Dialect/Func/Extensions/AllExtensions.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/DialectRegistry.h" @@ -22,8 +21,9 @@ #include "stablehlo/dialect/Register.h" #include "stablehlo/dialect/StablehloOps.h" -#include "stablehlo/transforms/Passes.h" #include "stablehlo/integrations/c/StablehloPasses.h" +#include "stablehlo/transforms/Passes.h" +#include "stablehlo/transforms/optimization/Passes.h" #include "Catalyst/IR/CatalystDialect.h" #include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h" @@ -49,6 +49,7 @@ int main(int argc, char **argv) mlir::registerAllPasses(); catalyst::registerAllCatalystPasses(); mlirRegisterAllStablehloPasses(); + mlir::stablehlo::registerOptimizationPasses(); mlir::DialectRegistry registry; mlir::registerAllDialects(registry); From f6930d4bab97f6e3fbaddbaabef8973eb6579618 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Tue, 29 Jul 2025 13:46:33 -0400 Subject: [PATCH 15/63] . --- mlir/mlir-hlo | 1 - 1 file changed, 1 deletion(-) delete mode 160000 mlir/mlir-hlo diff --git a/mlir/mlir-hlo b/mlir/mlir-hlo deleted file mode 160000 index 1dd2e71331..0000000000 --- a/mlir/mlir-hlo +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 1dd2e71331014ae0373f6bf900ce6be393357190 From ff2a9df968648fad87a3ce429cfc5843bdb81623 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Tue, 29 Jul 2025 13:48:30 -0400 Subject: [PATCH 16/63] .. --- mlir/stablehlo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/stablehlo b/mlir/stablehlo index f1f035fea3..69d6dae46e 160000 --- a/mlir/stablehlo +++ b/mlir/stablehlo @@ -1 +1 @@ -Subproject commit f1f035fea33dcfdd7c471eb7f39174b344003117 +Subproject commit 69d6dae46e1c7de36e6e6973654754f05353cba5 From 55598e06d87d8812883439ced1914ed07f3931d4 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Wed, 30 Jul 2025 17:33:19 -0400 Subject: [PATCH 17/63] burn merge breadcrumbs --- mlir/CMakeLists.txt | 3 - .../Catalyst/Transforms/CMakeLists.txt | 3 +- mlir/include/Catalyst/Transforms/Passes.td | 49 -- mlir/lib/Catalyst/Transforms/CMakeLists.txt | 8 - .../stablehlo_legalize_control_flow.cpp | 316 --------- .../Transforms/stablehlo_legalize_sort.cpp | 608 ------------------ .../Transforms/stablehlo_legalize_to_std.cpp | 264 -------- 7 files changed, 2 insertions(+), 1249 deletions(-) delete mode 100644 mlir/lib/Catalyst/Transforms/stablehlo_legalize_control_flow.cpp delete mode 100644 mlir/lib/Catalyst/Transforms/stablehlo_legalize_sort.cpp delete mode 100644 mlir/lib/Catalyst/Transforms/stablehlo_legalize_to_std.cpp diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index 661c5ab299..c47db99ccc 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -38,9 +38,6 @@ set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin) set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib) set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) -# Taken from mlir-hlo/mhlo/transforms/CMakeLists.txt. -# Unfortunately, AllMhloPasses doesn't appear to be exported. - list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/modules") diff --git a/mlir/include/Catalyst/Transforms/CMakeLists.txt b/mlir/include/Catalyst/Transforms/CMakeLists.txt index 1b4879ff71..b4977026e0 100644 --- a/mlir/include/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/include/Catalyst/Transforms/CMakeLists.txt @@ -3,7 +3,8 @@ mlir_tablegen(Passes.h.inc -gen-pass-decls -name catalyst) add_public_tablegen_target(MLIRCatalystPassIncGen) add_mlir_doc(Passes CatalystPasses ./ -gen-pass-doc) -# The following is taken from mhlo to build the --legalize-to-std pass +# The following is taken from tensorflow/mlir-hlo repo +# to build the --legalize-to-std pass set(LLVM_TARGET_DEFINITIONS stablehlo_legalize_to_standard_patterns.td) include_directories( ${CATALYST_MAIN_INCLUDE_DIR}/../stablehlo) diff --git a/mlir/include/Catalyst/Transforms/Passes.td b/mlir/include/Catalyst/Transforms/Passes.td index 46080b1f74..e731cecddc 100644 --- a/mlir/include/Catalyst/Transforms/Passes.td +++ b/mlir/include/Catalyst/Transforms/Passes.td @@ -282,53 +282,4 @@ def BufferDeallocation : Pass<"buffer-deallocation", "func::FuncOp"> { let constructor = "mlir::bufferization::createBufferDeallocationPass()"; } -// mhlo legalize sort pass. -// mhlo dropped the -legalize-sort pass when migrating to stablehlo. -// We manually add it back. -// -// This pass has been modified from its original form in the tensorflow/mlir-hlo repository at -// https://github.com/tensorflow/mlir-hlo/blob/a5529d99fc4d1132b0c282a053d26c11e6636b3a/mhlo/transforms/mhlo_passes.td -// released under the Apache License, Version 2.0, with the following copyright notice: -// -// * Licensed under the Apache License, Version 2.0 (the "License"); - -def StablehloLegalizeSortPass : Pass<"stablehlo-legalize-sort", "func::FuncOp"> { - let summary = "Legalize from Stablehlo sort to SCF control flow."; - let constructor = "createStablehloLegalizeSortPass()"; - let dependentDialects = ["arith::ArithDialect", - "bufferization::BufferizationDialect", - "scf::SCFDialect", "tensor::TensorDialect"]; -} - -// mhlo legalize to std pass. -// mhlo dropped the -legalize-to-std pass when migrating to stablehlo. -// We manually add it back. -// -// This pass has been modified from its original form in the tensorflow/mlir-hlo repository at -// https://github.com/tensorflow/mlir-hlo/blob/a5529d99fc4d1132b0c282a053d26c11e6636b3a/mhlo/transforms/mhlo_passes.td -// released under the Apache License, Version 2.0, with the following copyright notice: -// -// * Licensed under the Apache License, Version 2.0 (the "License"); - -def StablehloLegalizeToStandardPass : Pass<"stablehlo-legalize-to-std", "func::FuncOp"> { - let summary = "Legalize from MHLO dialect to standard dialect."; - let constructor = "createStablehloLegalizeToStdPass()"; -} - -// mhlo legalize to control flow pass. -// mhlo dropped the -legalize-to-control-flow pass when migrating to stablehlo. -// We manually add it back. -// -// This pass has been modified from its original form in the tensorflow/mlir-hlo repository at -// https://github.com/tensorflow/mlir-hlo/blob/a5529d99fc4d1132b0c282a053d26c11e6636b3a/mhlo/transforms/mhlo_passes.td -// released under the Apache License, Version 2.0, with the following copyright notice: -// -// * Licensed under the Apache License, Version 2.0 (the "License"); - -def StablehloLegalizeControlFlowPass : Pass<"stablehlo-legalize-control-flow", "func::FuncOp"> { - let summary = "Legalize from MHLO control flow to SCF control flow."; - let constructor = "createStablehloLegalizeControlFlowPass()"; - let dependentDialects = ["scf::SCFDialect", "tensor::TensorDialect"]; -} - #endif // CATALYST_PASSES diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index b67727a900..b4776af6a9 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -25,9 +25,6 @@ file(GLOB SRC scatter_lowering.cpp ScatterPatterns.cpp SplitMultipleTapes.cpp - stablehlo_legalize_control_flow.cpp - stablehlo_legalize_sort.cpp - stablehlo_legalize_to_std.cpp TBAAPatterns.cpp TBAATagsPass.cpp ) @@ -37,15 +34,10 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) set(LIBS ${dialect_libs} ${conversion_libs} - StablehloPasses - StablehloOps ) set(DEPENDS MLIRCatalystPassIncGen - MLIRStablehloLegalizeToStandardIncGen - StablehloBaseIncGen - StablehloOpsIncGen ) add_mlir_library(${LIBRARY_NAME} STATIC ${SRC} LINK_LIBS PRIVATE ${LIBS} DEPENDS ${DEPENDS}) diff --git a/mlir/lib/Catalyst/Transforms/stablehlo_legalize_control_flow.cpp b/mlir/lib/Catalyst/Transforms/stablehlo_legalize_control_flow.cpp deleted file mode 100644 index 3faf312560..0000000000 --- a/mlir/lib/Catalyst/Transforms/stablehlo_legalize_control_flow.cpp +++ /dev/null @@ -1,316 +0,0 @@ -// Copyright 2025 Xanadu Quantum Technologies Inc. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file is taken from the -// tensorflow/mlir-hlo -// repository, under the Apache 2.0 License, at -// https://github.com/tensorflow/mlir-hlo/blob/a5529d99fc4d1132b0c282a053d26c11e6636b3a/mhlo/transforms/legalize_control_flow/legalize_control_flow.cc -// with the following copyright notice: - -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// The modifications are porting the pass from the upstream MHLO namespace to -// catalyst namespace. - -// This file implements logic for lowering MHLO dialect to SCF dialect. -#include -#include -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project -#include "mlir/IR/Block.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/StablehloOps.h" -#include "stablehlo/transforms/Passes.h" -#include "llvm/Support/Casting.h" - -#include "Catalyst/IR/CatalystDialect.h" -#include "Catalyst/Transforms/Passes.h" - -using namespace mlir; -using namespace stablehlo; -using namespace catalyst; - -namespace catalyst { - -#define GEN_PASS_DEF_STABLEHLOLEGALIZECONTROLFLOWPASS -#define GEN_PASS_DECL_STABLEHLOLEGALIZECONTROLFLOWPASS -#include "Catalyst/Transforms/Passes.h.inc" - -} // namespace catalyst - -namespace { - -// All transformations in this file take mhlo blocks which end with -// stablehlo::ReturnOp and lower to SCF ops which end with scf::YieldOp. Inline an -// entire block with the only change being return -> yield. -void inlineMhloRegionIntoSCFRegion(PatternRewriter &rewriter, Region &mhlo, Region &scf) -{ - // Remove an existing block, then move the region over. - if (!scf.empty()) - rewriter.eraseBlock(&scf.back()); - rewriter.inlineRegionBefore(mhlo, scf, scf.end()); - // Fix up the terminator. - PatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToEnd(&scf.back()); - auto *terminator = scf.back().getTerminator(); - rewriter.replaceOpWithNewOp(terminator, terminator->getOperands()); -} - -// mhlo ops need inputs to be tensors, but scalar values can be a scalar tensor -// or a 1 element tensor. To handle this, collapse shape before extracting the -// scalar value when necessary. -Value extractTensorValue(OpBuilder &b, Value tensor) -{ - auto loc = tensor.getLoc(); - if (mlir::cast(tensor.getType()).hasRank() && - mlir::cast(tensor.getType()).getRank() != 0) { - tensor = - b.create(loc, tensor, SmallVector()); - } - return b.create(loc, tensor, ValueRange()); -} - -struct ScfForBounds { - Value lb; - Value ub; - Value step; - unsigned indexArgIndex; -}; - -std::optional extractForBounds(stablehlo::WhileOp op) -{ - auto &cond = op.getCond().front(); - auto &body = op.getBody().front(); - if (cond.getOperations().size() != 2) - return std::nullopt; - - auto matchBbArg = [](Value v, Block &block) -> std::optional { - if (!mlir::isa(v) || v.getParentBlock() != &block) - return std::nullopt; - return mlir::cast(v).getArgNumber(); - }; - - auto compare = llvm::dyn_cast(cond.front()); - // If the rhs of the comapare is defined outside the block, it's a constant - // within the loop. - if (!compare || compare.getComparisonDirection() != stablehlo::ComparisonDirection::LT || - compare.getRhs().getParentBlock() == &cond || - !getElementTypeOrSelf(compare.getLhs().getType()).isSignlessIntOrIndex()) { - return std::nullopt; - } - - auto iterArg = matchBbArg(compare.getLhs(), cond); - if (!iterArg) - return std::nullopt; - - auto add = llvm::dyn_cast_or_null( - body.getTerminator()->getOperand(*iterArg).getDefiningOp()); - if (!add || matchBbArg(add.getLhs(), body) != iterArg || - add.getRhs().getParentBlock() == &body) { - return std::nullopt; - } - - ScfForBounds bounds; - bounds.ub = compare.getRhs(); - bounds.step = add.getRhs(); - bounds.lb = op->getOperand(*iterArg); - bounds.indexArgIndex = *iterArg; - return bounds; -} - -// Rewrites `stablehlo.while` to `scf.while` or `scf.for`. -struct WhileOpPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(stablehlo::WhileOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - auto loc = op.getLoc(); - - if (auto bounds = extractForBounds(op)) { - auto newForOp = rewriter.create( - loc, extractTensorValue(rewriter, bounds->lb), - extractTensorValue(rewriter, bounds->ub), - extractTensorValue(rewriter, bounds->step), adaptor.getOperands()); - - rewriter.setInsertionPointToEnd(newForOp.getBody()); - // Inline while body, and only replace the mhlo.return with an scf.yield. - inlineMhloRegionIntoSCFRegion(rewriter, op.getBody(), newForOp.getRegion()); - auto indexArg = newForOp.getRegion().insertArgument( - unsigned{0}, newForOp.getLowerBound().getType(), loc); - auto oldIndexArg = newForOp.getRegion().getArgument(1 + bounds->indexArgIndex); - rewriter.setInsertionPointToStart(&newForOp.getRegion().front()); - auto indexArgTensor = - rewriter.create(loc, oldIndexArg.getType(), indexArg); - oldIndexArg.replaceAllUsesWith(indexArgTensor); - - rewriter.replaceOp(op, newForOp.getResults()); - return success(); - } - - auto newWhileOp = - rewriter.create(loc, op.getResultTypes(), adaptor.getOperands()); - - // Inline while condition. The block is the same, except the boolean result - // needs to be extracted and used with an scf.condition. - rewriter.inlineRegionBefore(op.getCond(), newWhileOp.getBefore(), - newWhileOp.getBefore().end()); - auto conditionReturn = - cast(newWhileOp.getBefore().front().getTerminator()); - rewriter.setInsertionPointToEnd(&newWhileOp.getBefore().front()); - Value i1 = extractTensorValue(rewriter, conditionReturn->getOperand(0)); - rewriter.replaceOpWithNewOp(conditionReturn, i1, - newWhileOp.getBeforeArguments()); - - // Inline while body, and only replace the mhlo.return with an scf.yield. - inlineMhloRegionIntoSCFRegion(rewriter, op.getBody(), newWhileOp.getAfter()); - - rewriter.replaceOp(op, newWhileOp.getResults()); - return success(); - } -}; - -// Rewrites `mhlo.if` to `scf.if`. -struct IfOpPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(stablehlo::IfOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - auto scfIf = rewriter.create(op.getLoc(), op.getResultTypes(), - extractTensorValue(rewriter, adaptor.getPred()), - /*withElseRegion=*/true); - inlineMhloRegionIntoSCFRegion(rewriter, op.getTrueBranch(), scfIf.getThenRegion()); - inlineMhloRegionIntoSCFRegion(rewriter, op.getFalseBranch(), scfIf.getElseRegion()); - rewriter.replaceOp(op, scfIf.getResults()); - return success(); - } -}; - -// Rewrites `mhlo.case` to a nested `scf.if`. -struct CaseOpPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - // Recursively create if/else ops to handle each possible value in a case op. - scf::IfOp createNestedCases(int currentIdx, CaseOp op, OpAdaptor adaptor, - PatternRewriter &outerBuilder) const - { - Location loc = op.getLoc(); - Value idxValue = adaptor.getIndex(); - auto finalIdx = op.getBranches().size() - 2; - - // Determine if the current index matches the case index. - auto scalarType = idxValue.getType(); - auto shapedType = mlir::cast(scalarType); - auto constAttr = DenseElementsAttr::get( - shapedType, {mlir::cast(outerBuilder.getI32IntegerAttr(currentIdx))}); - Value currentIdxVal = - outerBuilder.create(loc, idxValue.getType(), constAttr); - - auto scfIf = outerBuilder.create( - loc, op.getResultTypes(), - extractTensorValue(outerBuilder, - outerBuilder.create( - loc, idxValue, currentIdxVal, ComparisonDirection::EQ)), - /*withElseRegion=*/true); - inlineMhloRegionIntoSCFRegion(outerBuilder, op.getBranches()[currentIdx], - scfIf.getThenRegion()); - int nextIdx = currentIdx + 1; - // Don't recurse for the final default block. - if (currentIdx == static_cast(finalIdx)) { - inlineMhloRegionIntoSCFRegion(outerBuilder, op.getBranches()[nextIdx], - scfIf.getElseRegion()); - } - else { - PatternRewriter::InsertionGuard guard(outerBuilder); - outerBuilder.setInsertionPointToEnd(&scfIf.getElseRegion().back()); - auto innerIf = createNestedCases(nextIdx, op, adaptor, outerBuilder); - outerBuilder.create(op.getLoc(), innerIf.getResults()); - } - return scfIf; - } - - LogicalResult matchAndRewrite(stablehlo::CaseOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - // Inline the op if there is only a default block. - if (op.getBranches().size() == 1) { - Block &block = op.getBranches().front().front(); - auto results = block.getTerminator()->getOperands(); - // Remove the mhlo.return terminator, then inline the block. - rewriter.eraseOp(block.getTerminator()); - rewriter.inlineBlockBefore(/*source=*/&block, /*dest=*/op.getOperation(), - /*argValues=*/{}); - rewriter.replaceOp(op, results); - return success(); - } - - // Begin recursion with case 0. - rewriter.replaceOp(op, createNestedCases(0, op, adaptor, rewriter).getResults()); - return success(); - } -}; - -struct StablehloLegalizeControlFlowPass - : public catalyst::impl::StablehloLegalizeControlFlowPassBase< - StablehloLegalizeControlFlowPass> { - // Perform the lowering to MLIR control flow. - void runOnOperation() override - { - func::FuncOp f = getOperation(); - MLIRContext *ctx = f.getContext(); - - RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - - mlir::ConversionTarget target(*ctx); - target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - target.addIllegalOp(); - - if (failed(applyPartialConversion(f, target, std::move(patterns)))) { - signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr catalyst::createStablehloLegalizeControlFlowPass() -{ - return std::make_unique(); -} diff --git a/mlir/lib/Catalyst/Transforms/stablehlo_legalize_sort.cpp b/mlir/lib/Catalyst/Transforms/stablehlo_legalize_sort.cpp deleted file mode 100644 index e821369dda..0000000000 --- a/mlir/lib/Catalyst/Transforms/stablehlo_legalize_sort.cpp +++ /dev/null @@ -1,608 +0,0 @@ -// Copyright 2025 Xanadu Quantum Technologies Inc. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - - -// This file is taken from the -// tensorflow/mlir-hlo -// repository, under the Apache 2.0 License, at -// https://github.com/tensorflow/mlir-hlo/blob/a5529d99fc4d1132b0c282a053d26c11e6636b3a/mhlo/transforms/legalize_sort/legalize_sort.cc -// with the following copyright notice: - - /* Copyright 2019 The OpenXLA Authors. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ - -// The modifications are porting the pass from the upstream MHLO namespace to -// catalyst namespace. - -// This file implements logic for lowering stablehlo.sort to the SCF dialect. -#include -#include -#include - -#include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project -#include "mlir/IR/Block.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeRange.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/StablehloOps.h" -#include "stablehlo/transforms/Passes.h" - -#include "Catalyst/IR/CatalystDialect.h" -#include "Catalyst/Transforms/Passes.h" - -using namespace mlir; -using namespace stablehlo; -using namespace catalyst; - -namespace catalyst { - -#define GEN_PASS_DEF_STABLEHLOLEGALIZESORTPASS -#define GEN_PASS_DECL_STABLEHLOLEGALIZESORTPASS -#include "Catalyst/Transforms/Passes.h.inc" - -} // namespace catalyst - -namespace { - -using ::mlir::arith::AddIOp; -using ::mlir::arith::MinSIOp; -using ::mlir::arith::SelectOp; - -constexpr int64_t kInsertionSortSize = 16; - -// Inlines the `comparator` region (without terminator) at the current insertion -// point, replacing the arguments with the given values from `lhs` and `rhs`. -Value emitComparison(ImplicitLocOpBuilder& b, SmallVector& lhs, - SmallVector& rhs, Region& comparator) { - assert(comparator.hasOneBlock() && "Comparator must have only one block."); - Block& block = comparator.front(); - assert(block.getTerminator()->getOperands().size() == 1 && - "Comparator must return a single value"); - - IRMapping mapping; - for (auto [idx, arg] : llvm::enumerate(comparator.getArguments())) { - Value value = idx % 2 == 0 ? lhs[idx / 2] : rhs[idx / 2]; - Type type = RankedTensorType::get({}, value.getType()); - mapping.map(arg, b.create(type, value)); - } - - for (Operation& op : block.without_terminator()) b.clone(op, mapping); - Value result = mapping.lookup(block.getTerminator()->getOperands().front()); - - return b.create(result, ValueRange()); -} - -// Emits a binary search of `pivots` in `arrayMemrefs` (all rank 1) in the range -// [`left`;`right`). `arrayMemrefs` must be sorted according to `comparator`. -Value emitBinarySearch(ImplicitLocOpBuilder& b, Value leftInit, Value rightInit, - SmallVector& pivots, ValueRange arrayMemrefs, - Region& comparator) { - SmallVector types{leftInit.getType(), rightInit.getType()}; - ArithBuilder arith(b, b.getLoc()); - - // while ( - auto whileOp = - b.create(types, SmallVector{leftInit, rightInit}); - OpBuilder::InsertionGuard guard(b); - - // left < right) { - Block* before = b.createBlock(&whileOp.getBefore(), {}, types, - {whileOp.getLoc(), whileOp.getLoc()}); - { - Value left = before->getArgument(0), right = before->getArgument(1); - b.setInsertionPointToEnd(before); - b.create(arith.slt(left, right), before->getArguments()); - } - - Block* after = b.createBlock(&whileOp.getAfter(), {}, types, - {whileOp.getLoc(), whileOp.getLoc()}); - { - Value left = after->getArgument(0), right = after->getArgument(1); - b.setInsertionPointToEnd(after); - // int mid = (left + right) >> 1; - Value one = b.create(1); - Value mid = b.create(arith.add(left, right), one); - Value midPlusOne = b.create(mid, one); - - auto arraysAtMid = llvm::to_vector( - llvm::map_range(arrayMemrefs, [&](Value arrayMemref) -> Value { - return b.create(arrayMemref, mid); - })); - Value cond = emitComparison(b, pivots, arraysAtMid, comparator); - // if (comparator(pivot, array[mid])) - // right = mid; - // else - // left = mid + 1; - Value newLeft = arith.select(cond, left, midPlusOne); - Value newRight = arith.select(cond, mid, right); - - // } - b.create(ValueRange{newLeft, newRight}); - } - - return whileOp.getResult(0); -} - -SmallVector loadTensorElements(ImplicitLocOpBuilder& b, - ValueRange tensors, Value index) { - return llvm::to_vector(llvm::map_range(tensors, [&](Value tensor) -> Value { - return b.create(tensor, index); - })); -} - -SmallVector loadMemrefElements(ImplicitLocOpBuilder& b, - ValueRange memrefs, Value index) { - return llvm::to_vector(llvm::map_range(memrefs, [&](Value memref) -> Value { - Type type = mlir::cast(memref.getType()).getElementType(); - return b.create(type, memref, index); - })); -} - -void storeMemrefElements(ImplicitLocOpBuilder& b, ValueRange memrefs, - Value index, ValueRange values) { - for (auto [value, memref] : llvm::zip(values, memrefs)) { - b.create(value, memref, index); - } -} - -// Insertion sorts `inputTensors` in the range [`lo`; `hi`), storing the results -// in `outputMemrefs`. `inputTensors` and `outputMemrefs` must all be rank 1 and -// of identical size. -void emitInsertionSort(ImplicitLocOpBuilder& b, Value lo, Value hi, - ValueRange inputTensors, ValueRange outputMemrefs, - mlir::Region& comparator) { - ArithBuilder arith(b, b.getLoc()); - Value zero = b.create(0); - Value one = b.create(1); - - // array[lo] = tensors[lo]; - storeMemrefElements(b, outputMemrefs, lo, - loadTensorElements(b, inputTensors, lo)); - - // for (int start = lo + 1; start < hi; ++start) - { - auto forOp = b.create(arith.add(lo, one), hi, one); - OpBuilder::InsertionGuard outerGuard(b); - b.setInsertionPointToStart(forOp.getBody()); - Value start = forOp.getInductionVar(); - - // T pivot = tensors[start]; - auto pivots = loadTensorElements(b, inputTensors, start); - - // int index = binarySearch(lo, start, pivot, array, comparator); - auto index = - emitBinarySearch(b, lo, start, pivots, outputMemrefs, comparator); - - // int n = start - index; // The number of elements to move - Value n = arith.sub(start, index); - - // memmove(&array[index + 1], &array[index], n * sizeof(T)) - // memref::CopyOp would be nice to use here, but: - // 1. It lowers to a quite inefficient library call in the general case - // (strides != 1). - // 2. It implements memcpy semantics, but we need memmove here. - // So we go with a loop instead. - auto copyForOp = b.create(zero, n, one); - { - OpBuilder::InsertionGuard innerGuard(b); - b.setInsertionPointToStart(copyForOp.getBody()); - Value copyLoopIndex = copyForOp.getBody()->getArgument(0); - - Value dstIndex = arith.sub(start, copyLoopIndex); - Value srcIndex = arith.sub(dstIndex, one); - storeMemrefElements(b, outputMemrefs, dstIndex, - loadMemrefElements(b, outputMemrefs, srcIndex)); - } - // array[index] = pivot; - storeMemrefElements(b, outputMemrefs, index, pivots); - } -} - -void emitMerge(ImplicitLocOpBuilder& b, Value lo, Value mid, Value hi, - ValueRange readBufs, ValueRange writeBufs, - mlir::Region& comparator) { - ArithBuilder arith(b, b.getLoc()); - // The while loop runs until we reach the end of either interval. It has three - // loop-carried variables: - // 1. current output index - // 2. current read index for interval 1 - // 3. current read index for interval 2 - SmallVector whileArgTypes{lo.getType(), lo.getType(), mid.getType()}; - SmallVector whileInitArgs{lo, lo, mid}; - SmallVector whileArgLocs(whileArgTypes.size(), b.getLoc()); - - // while( - auto whileOp = b.create(whileArgTypes, whileInitArgs); - { - OpBuilder::InsertionGuard guard(b); - { - Block* before = - b.createBlock(&whileOp.getBefore(), {}, whileArgTypes, whileArgLocs); - Value i0 = before->getArgument(1), i1 = before->getArgument(2); - b.setInsertionPointToEnd(before); - - // i0 < mid && i1 < hi) { - Value inbounds0 = arith.slt(i0, mid); - Value inbounds1 = arith.slt(i1, hi); - - b.create(arith._and(inbounds0, inbounds1), - before->getArguments()); - } - - { - Block* after = - b.createBlock(&whileOp.getAfter(), {}, whileArgTypes, whileArgLocs); - Value iOut = after->getArgument(0), i0 = after->getArgument(1), - i1 = after->getArgument(2); - b.setInsertionPointToEnd(after); - - // auto vals0 = readBufs[i0], vals1 = readBufs[i1]; - SmallVector vals0 = loadMemrefElements(b, readBufs, i0); - SmallVector vals1 = loadMemrefElements(b, readBufs, i1); - - // writeBufs[iOut] = comparator(vals1, vals0) - // ? readBufs[i1++] : readBufs[i0++]; - Value cmp = emitComparison(b, vals1, vals0, comparator); - SmallVector pickedVals; - for (auto [val0, val1] : llvm::zip(vals0, vals1)) { - pickedVals.push_back(b.create(cmp, val1, val0)); - } - storeMemrefElements(b, writeBufs, iOut, pickedVals); - - Value one = b.create(1); - Value nexti0 = b.create(cmp, i0, arith.add(i0, one)); - Value nexti1 = b.create(cmp, arith.add(i1, one), i1); - // ++iOut; - Value nextIOut = b.create(iOut, one); - b.create(ValueRange{nextIOut, nexti0, nexti1}); - } - } - - // At this point, exactly one of the input ranges will have leftover elements. - Value iOut = whileOp->getResult(0); - Value i0 = whileOp->getResult(1); - Value i1 = whileOp->getResult(2); - - // We could use memref::CopyOp here, but typically, there aren't many leftover - // elements for randomly shuffled inputs. - Value leftoverIn0 = arith.slt(i0, mid); - Value start = arith.select(leftoverIn0, i0, i1); - Value end = arith.select(leftoverIn0, mid, hi); - Value n = arith.sub(end, start); - - Value zero = b.create(0); - Value one = b.create(1); - auto forOp = b.create(zero, n, one); - b.setInsertionPointToStart(forOp.getBody()); - Value copyIndex = forOp.getBody()->getArgument(0); - - Value srcIndex = arith.add(start, copyIndex); - Value dstIndex = arith.add(iOut, copyIndex); - storeMemrefElements(b, writeBufs, dstIndex, - loadMemrefElements(b, readBufs, srcIndex)); -} - -// Emits a bottom up merge sort of `inputTensors` in the range [`lo`; `hi`), and -// writes the results to either `outputs0` or `outputs1`. -// Returns 0 if the results are in `outputs0`, 1 if they are in `outputs1`. -// TODO(jreiffers): Consider implementing top-down merge sort. -Value emitBottomUpMergeSort(ImplicitLocOpBuilder& b, Value lo, Value hi, - int64_t staticSortDimSize, ValueRange inputTensors, - ValueRange outputs0, ValueRange outputs1, - mlir::Region& comparator) { - ArithBuilder arith(b, b.getLoc()); - Value size = arith.sub(hi, lo); - - Value zero = b.create(0); - Value insertionSortSize = - b.create(kInsertionSortSize); - - // Run insertion sort on blocks of size kInsertionSortSize. - // for (int start = 0; start < size; start += kInsertionSortSize) { - { - auto forOp = b.create(zero, size, insertionSortSize); - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart(forOp.getBody()); - Value start = forOp.getBody()->getArgument(0); - Value end = arith.add( - b.create(arith.add(start, insertionSortSize), size), lo); - emitInsertionSort(b, start, end, inputTensors, outputs0, comparator); - } - - Value initParity = b.create(0, 1); - if (staticSortDimSize >= 0 && staticSortDimSize < kInsertionSortSize) { - return initParity; - } - - // The while arguments are: - // 1. the current size - // 2. the original index of the buffers we're currently reading from - // 3. the buffers we're currently reading from - // 4. the buffers we're currently writing to. - // - // 1 gets doubled each iteration, 2 gets negated, 3 and 4 are swapped. - // int currentSize = 16; - SmallVector whileInitArgs{insertionSortSize, initParity}; - // First we read from `outputs0` (initialized by the insertion sort above). - llvm::copy(outputs0, std::back_inserter(whileInitArgs)); - llvm::copy(outputs1, std::back_inserter(whileInitArgs)); - - SmallVector whileArgTypes; - for (auto val : whileInitArgs) whileArgTypes.push_back(val.getType()); - - SmallVector whileArgLocs(whileArgTypes.size(), b.getLoc()); - - // while ( - auto whileOp = b.create(whileArgTypes, whileInitArgs); - OpBuilder::InsertionGuard guard(b); - - // currentSize < totalSize) - { - Block* before = - b.createBlock(&whileOp.getBefore(), {}, whileArgTypes, whileArgLocs); - Value currentSize = before->getArgument(0); - b.setInsertionPointToEnd(before); - b.create(arith.slt(currentSize, size), - before->getArguments()); - } - - size_t numArgs = inputTensors.size(); - // { - { - Block* after = - b.createBlock(&whileOp.getAfter(), {}, whileArgTypes, whileArgLocs); - - Value currentSize = after->getArgument(0); - Value parity = after->getArgument(1); - auto readBufs = after->getArguments().drop_front(2).take_front(numArgs); - auto writeBufs = after->getArguments().take_back(numArgs); - - Value twoCurrentSize = arith.add(currentSize, currentSize); - - // for (int start = 0; start < size; start += 2*currentSize) { - { - auto forOp = b.create(zero, size, twoCurrentSize); - b.setInsertionPointToStart(forOp.getBody()); - Value start = forOp.getBody()->getArgument(0); - - Value mid = b.create(size, arith.add(start, currentSize)); - Value end = b.create(size, arith.add(start, twoCurrentSize)); - emitMerge(b, start, mid, end, readBufs, writeBufs, comparator); - b.setInsertionPointAfter(forOp); - } - // } - - // parity = !parity; - Value one = b.create(1, 1); - Value notParity = arith.sub(one, parity); - // currentSize *= 2; - SmallVector nextWhileArgs{twoCurrentSize, notParity}; - llvm::copy(writeBufs, std::back_inserter(nextWhileArgs)); - llvm::copy(readBufs, std::back_inserter(nextWhileArgs)); - b.create(nextWhileArgs); - } - // } - - // The result is the parity bit. - return whileOp.getResults().drop_front(1).front(); -} - -// Helper struct for extracting 1d slices from tensors and memrefs. -struct Slicer { - Slicer(OpBuilder& b, uint64_t sortDim, Value sortDimSize, ValueRange ivs) - : sizes(ivs.size() + 1, b.getI64IntegerAttr(1)), - strides(ivs.size() + 1, b.getI64IntegerAttr(1)) { - sizes[sortDim] = sortDimSize; - for (size_t i = 0; i < ivs.size() + 1; ++i) { - if (i == sortDim) { - offsets.push_back(b.getI64IntegerAttr(0)); - } else { - offsets.push_back(ivs[i - static_cast(i > sortDim)]); - } - } - } - - RankedTensorType toSlicedType(RankedTensorType sourceType) { - return tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( - /*resultRank=*/1, sourceType, offsets, sizes, strides); - } - - MemRefType toSlicedType(MemRefType sourceType) { - return mlir::cast(memref::SubViewOp::inferRankReducedResultType( - {ShapedType::kDynamic} /*1D output*/, sourceType, offsets, sizes, - strides)); - } - - template - Value slice(ImplicitLocOpBuilder& b, Value input) { - Ty ty = mlir::cast(input.getType()); - return b.create(toSlicedType(ty), input, offsets, sizes, strides) - .getResult(); - } - - Value apply(ImplicitLocOpBuilder& b, Value input) { - Type inTy = input.getType(); - if (mlir::isa(inTy)) { - return slice(b, input); - } - assert(mlir::isa(inTy)); - return slice(b, input); - } - - SmallVector offsets; - SmallVector sizes; - SmallVector strides; -}; - -SmallVector sliceMemrefsOrTensors(ImplicitLocOpBuilder& b, - SmallVector& ivs, - Value sortDimSize, - ValueRange memrefsOrTensors, - SortOp op) { - if (ivs.empty()) return memrefsOrTensors; - - SmallVector outputs; - Slicer slicer(b, op.getDimension(), sortDimSize, ivs); - // Create subviews/slices. - for (Value out : memrefsOrTensors) { - outputs.push_back(slicer.apply(b, out)); - } - - return outputs; -} - -struct SortOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(SortOp op, - PatternRewriter& rewriter) const override { - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - // Note: the output memrefs aren't necessarily the ones that we return, - SmallVector outputMemrefs; - SmallVector scratchMemrefs; - - Value firstOperand = op.getOperands().front(); - auto firstOperandType = mlir::cast(firstOperand.getType()); - int64_t inputRank = firstOperandType.getRank(); - - Value sortDimSize = b.createOrFold( - firstOperand, b.create(op.getDimension())); - int64_t staticSortDimSize = firstOperandType.getDimSize(op.getDimension()); - - SmallVector dynamicDims; - for (int i = 0; i < inputRank; ++i) { - if (!firstOperandType.isDynamicDim(i)) continue; - Value index = b.create(i); - Value dimOp = b.create(firstOperand, index); - dynamicDims.push_back(dimOp); - } - - // Allocate output and scratch memrefs. If the size of the sort dimension is - // statically known to be <= kInsertionSortSize, `scratchMemrefs` are unused - // and will be cleaned up later. - for (auto input : op.getOperands()) { - auto inputType = mlir::cast(input.getType()); - auto memRefType = - MemRefType::get(inputType.getShape(), inputType.getElementType()); - - outputMemrefs.push_back( - b.create(memRefType, dynamicDims)); - scratchMemrefs.push_back( - b.create(memRefType, dynamicDims)); - } - - b.setInsertionPoint(op); - Value zero = b.create(0); - Value one = b.create(1); - - Value forInitArg = b.create(0, 1); - SmallVector forOps; - SmallVector ivs; - forOps.reserve(inputRank - 1); - ivs.reserve(inputRank - 1); - for (int64_t i = 0; i < inputRank; ++i) { - if (i != static_cast(op.getDimension())) { - Value dim = b.create(i); - Value ub = b.create(firstOperand, dim); - scf::ForOp& forOp = forOps.emplace_back( - b.create(zero, ub, one, ValueRange{forInitArg})); - ivs.push_back(forOp.getInductionVar()); - b.setInsertionPointToStart(&forOp.getRegion().front()); - } - } - SmallVector inputs = - sliceMemrefsOrTensors(b, ivs, sortDimSize, op.getOperands(), op); - SmallVector outputs = - sliceMemrefsOrTensors(b, ivs, sortDimSize, outputMemrefs, op); - SmallVector scratches = - sliceMemrefsOrTensors(b, ivs, sortDimSize, scratchMemrefs, op); - - Value parity = - emitBottomUpMergeSort(b, zero, sortDimSize, staticSortDimSize, inputs, - outputs, scratches, op.getRegion()); - - // Pass the parity bit through the for loops. - for (auto i = static_cast(forOps.size() - 1); i >= 0; --i) { - b.setInsertionPointToEnd(&forOps[i].getRegion().front()); - b.create(ValueRange{parity}); - parity = forOps[i]->getResult(0); - } - b.setInsertionPoint(op); - - SmallVector outputTensors; - for (auto [out0, out1] : llvm::zip(outputMemrefs, scratchMemrefs)) { - outputTensors.push_back(b.create( - b.create(parity, out1, out0), /*restrict=*/true)); - } - - rewriter.replaceOp(op, outputTensors); - return success(); - } -}; - -struct StablehloLegalizeSortPass - : public catalyst::impl::StablehloLegalizeSortPassBase { - // Perform the lowering to MLIR control flow. - void runOnOperation() override { - func::FuncOp f = getOperation(); - MLIRContext* ctx = f.getContext(); - - RewritePatternSet patterns(ctx); - patterns.add(ctx); - - mlir::ConversionTarget target(*ctx); - target.markUnknownOpDynamicallyLegal([](Operation*) { return true; }); - target.addIllegalOp(); - - if (failed(applyPartialConversion(f, target, std::move(patterns)))) { - signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr catalyst::createStablehloLegalizeSortPass() -{ - return std::make_unique(); -} diff --git a/mlir/lib/Catalyst/Transforms/stablehlo_legalize_to_std.cpp b/mlir/lib/Catalyst/Transforms/stablehlo_legalize_to_std.cpp deleted file mode 100644 index 371696c740..0000000000 --- a/mlir/lib/Catalyst/Transforms/stablehlo_legalize_to_std.cpp +++ /dev/null @@ -1,264 +0,0 @@ -// Copyright 2025 Xanadu Quantum Technologies Inc. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file is taken from the -// tensorflow/mlir-hlo -// repository, under the Apache 2.0 License, at -// https://github.com/tensorflow/mlir-hlo/blob/a5529d99fc4d1132b0c282a053d26c11e6636b3a/mhlo/transforms/legalize_to_standard/legalize_to_standard.cc -// with the following copyright notice: - -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// The modifications are porting the pass from the upstream MHLO namespace to -// catalyst namespace. - -// This file implements logic for lowering MHLO dialect to Standard dialect. - -#include -#include -#include - -// #include "mhlo/transforms/rewriters.h" // (??) -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "stablehlo/dialect/StablehloOps.h" -#include "stablehlo/transforms/Passes.h" - -#include "Catalyst/IR/CatalystDialect.h" -#include "Catalyst/Transforms/Passes.h" - -using namespace mlir; -using namespace stablehlo; -using namespace catalyst; - -namespace catalyst { - -#define GEN_PASS_DEF_STABLEHLOLEGALIZETOSTANDARDPASS -#define GEN_PASS_DECL_STABLEHLOLEGALIZETOSTANDARDPASS -#include "Catalyst/Transforms/Passes.h.inc" -#include "Catalyst/Transforms/generated_stablehlo_legalize_to_standard.cpp.inc" - -} // namespace catalyst - -namespace { - -class CompareIConvert : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(stablehlo::CompareOp op, PatternRewriter &rewriter) const override - { - auto lhs = op.getLhs(); - auto rhs = op.getRhs(); - auto lhsType = mlir::cast(lhs.getType()); - auto rhsType = mlir::cast(rhs.getType()); - - // Broadcasting not supported by this rewrite. - if (lhsType.getShape() != rhsType.getShape()) - return failure(); - - if (!lhsType.getElementType().isSignlessInteger() || - !rhsType.getElementType().isSignlessInteger()) - return failure(); - - std::optional comparePredicate = std::nullopt; - switch (op.getComparisonDirection()) { - case ComparisonDirection::EQ: - comparePredicate = arith::CmpIPredicate::eq; - break; - case ComparisonDirection::NE: - comparePredicate = arith::CmpIPredicate::ne; - break; - case ComparisonDirection::LT: - comparePredicate = arith::CmpIPredicate::slt; - break; - case ComparisonDirection::LE: - comparePredicate = arith::CmpIPredicate::sle; - break; - case ComparisonDirection::GT: - comparePredicate = arith::CmpIPredicate::sgt; - break; - case ComparisonDirection::GE: - comparePredicate = arith::CmpIPredicate::sge; - break; - } - - if (!comparePredicate.has_value()) - return failure(); - - rewriter.replaceOpWithNewOp(op, comparePredicate.value(), lhs, rhs); - return success(); - } -}; - -class CompareFConvert : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(stablehlo::CompareOp op, PatternRewriter &rewriter) const override - { - auto lhs = op.getLhs(); - auto rhs = op.getRhs(); - auto lhsType = mlir::cast(lhs.getType()); - auto rhsType = mlir::cast(rhs.getType()); - - // Broadcasting not supported by this rewrite. - if (lhsType.getShape() != rhsType.getShape()) - return failure(); - - if (!mlir::isa(lhsType.getElementType()) || - !mlir::isa(rhsType.getElementType())) - return failure(); - - std::optional comparePredicate = std::nullopt; - switch (op.getComparisonDirection()) { - case ComparisonDirection::EQ: - comparePredicate = arith::CmpFPredicate::OEQ; - break; - case ComparisonDirection::NE: - comparePredicate = arith::CmpFPredicate::UNE; - break; - case ComparisonDirection::LT: - comparePredicate = arith::CmpFPredicate::OLT; - break; - case ComparisonDirection::LE: - comparePredicate = arith::CmpFPredicate::OLE; - break; - case ComparisonDirection::GT: - comparePredicate = arith::CmpFPredicate::OGT; - break; - case ComparisonDirection::GE: - comparePredicate = arith::CmpFPredicate::OGE; - break; - } - - if (!comparePredicate.has_value()) - return failure(); - - rewriter.replaceOpWithNewOp(op, comparePredicate.value(), lhs, rhs); - return success(); - } -}; - -// Replace IotaOp with an integer constant. A ConvertOp is added to -// convert the integer constant to iota result type. For complex types, the real -// part is replaced with the generated constant and the imaginary part is -// replaced with zero tensor. -class ConvertIotaOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(stablehlo::IotaOp op, PatternRewriter &rewriter) const override - { - auto outputType = mlir::cast(op.getType()); - auto outputSize = outputType.getNumElements(); - auto dimension = op.getIotaDimension(); - auto maxDimSize = outputType.getDimSize(dimension); - - auto elementType = outputType.getElementType(); - int bitwidth; - - auto complexTy = mlir::dyn_cast(elementType); - Type intOrFloatTy = elementType; - if (complexTy) - intOrFloatTy = complexTy.getElementType(); - - bitwidth = intOrFloatTy.getIntOrFloatBitWidth(); - llvm::SmallVector values; - values.reserve(outputSize); - - int64_t increaseStride = outputSize; - for (uint64_t i = 0; i <= dimension; i++) { - increaseStride /= outputType.getDimSize(i); - } - - int64_t currentValue = 0; - for (int i = 0; i < outputSize; i++) { - int64_t value = (currentValue / increaseStride) % maxDimSize; - values.push_back(APInt(bitwidth, value)); - ++currentValue; - } - - auto intShapeType = RankedTensorType::get( - outputType.getShape(), IntegerType::get(rewriter.getContext(), bitwidth)); - auto loc = op.getLoc(); - auto integerConst = rewriter.create( - loc, DenseIntElementsAttr::get(intShapeType, values)); - - auto intOrFloatShapeTy = RankedTensorType::get(outputType.getShape(), intOrFloatTy); - - auto iotaConst = rewriter.create(loc, intOrFloatShapeTy, integerConst); - - // For int/float types we are done, replace op and return. - if (!complexTy) { - rewriter.replaceOp(op, iotaConst.getResult()); - return success(); - } - - // For complex types, generate a constant tensor of zeroes for the imaginary - // part and use iota_const for real part. - auto zeroes = rewriter.create( - loc, DenseIntElementsAttr::get(intShapeType, APInt(bitwidth, 0))); - auto imagZeroes = rewriter.create(loc, intOrFloatShapeTy, zeroes); - rewriter.replaceOpWithNewOp(op, iotaConst, imagZeroes); - return success(); - } -}; - -void populateStablehloToStdPatterns(RewritePatternSet *patterns, mlir::MLIRContext *ctx) -{ - populateWithGenerated(*patterns); - patterns->add(ctx); -} - -struct StablehloLegalizeToStandardPass - : public catalyst::impl::StablehloLegalizeToStandardPassBase { - void getDependentDialects(DialectRegistry ®istry) const override - { - registry.insert(); - } - - /// Perform the lowering to Standard dialect. - void runOnOperation() override - { - RewritePatternSet patterns(&getContext()); - populateStablehloToStdPatterns(&patterns, &getContext()); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) - return signalPassFailure(); - } -}; -} // end anonymous namespace - -std::unique_ptr catalyst::createStablehloLegalizeToStdPass() -{ - return std::make_unique(); -} From 41abbb3eb89276f34c6777a82020cff266ce8367 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Wed, 30 Jul 2025 17:38:33 -0400 Subject: [PATCH 18/63] cleanup breadcrumbs --- mlir/include/CMakeLists.txt | 6 +++--- mlir/include/{mlir-hlo => stablehlo}/CMakeLists.txt | 0 mlir/include/{mlir-hlo => stablehlo}/Passes.h | 0 mlir/include/{mlir-hlo => stablehlo}/Passes.td | 0 .../mhlo_legalize_to_standard_patterns.td | 0 mlir/lib/CMakeLists.txt | 8 ++++---- mlir/lib/{mlir-hlo => stablehlo}/CMakeLists.txt | 0 .../mhlo_legalize_control_flow.cpp | 0 mlir/lib/{mlir-hlo => stablehlo}/mhlo_legalize_sort.cpp | 0 mlir/lib/{mlir-hlo => stablehlo}/mhlo_legalize_to_std.cpp | 0 10 files changed, 7 insertions(+), 7 deletions(-) rename mlir/include/{mlir-hlo => stablehlo}/CMakeLists.txt (100%) rename mlir/include/{mlir-hlo => stablehlo}/Passes.h (100%) rename mlir/include/{mlir-hlo => stablehlo}/Passes.td (100%) rename mlir/include/{mlir-hlo => stablehlo}/mhlo_legalize_to_standard_patterns.td (100%) rename mlir/lib/{mlir-hlo => stablehlo}/CMakeLists.txt (100%) rename mlir/lib/{mlir-hlo => stablehlo}/mhlo_legalize_control_flow.cpp (100%) rename mlir/lib/{mlir-hlo => stablehlo}/mhlo_legalize_sort.cpp (100%) rename mlir/lib/{mlir-hlo => stablehlo}/mhlo_legalize_to_std.cpp (100%) diff --git a/mlir/include/CMakeLists.txt b/mlir/include/CMakeLists.txt index 6f6a692c66..4dc4680506 100644 --- a/mlir/include/CMakeLists.txt +++ b/mlir/include/CMakeLists.txt @@ -1,9 +1,9 @@ add_subdirectory(Catalyst) -add_subdirectory(Quantum) -add_subdirectory(QEC) add_subdirectory(Gradient) add_subdirectory(Ion) add_subdirectory(MBQC) add_subdirectory(Mitigation) +add_subdirectory(QEC) +add_subdirectory(Quantum) +add_subdirectory(stablehlo) add_subdirectory(Test) -add_subdirectory(mlir-hlo) diff --git a/mlir/include/mlir-hlo/CMakeLists.txt b/mlir/include/stablehlo/CMakeLists.txt similarity index 100% rename from mlir/include/mlir-hlo/CMakeLists.txt rename to mlir/include/stablehlo/CMakeLists.txt diff --git a/mlir/include/mlir-hlo/Passes.h b/mlir/include/stablehlo/Passes.h similarity index 100% rename from mlir/include/mlir-hlo/Passes.h rename to mlir/include/stablehlo/Passes.h diff --git a/mlir/include/mlir-hlo/Passes.td b/mlir/include/stablehlo/Passes.td similarity index 100% rename from mlir/include/mlir-hlo/Passes.td rename to mlir/include/stablehlo/Passes.td diff --git a/mlir/include/mlir-hlo/mhlo_legalize_to_standard_patterns.td b/mlir/include/stablehlo/mhlo_legalize_to_standard_patterns.td similarity index 100% rename from mlir/include/mlir-hlo/mhlo_legalize_to_standard_patterns.td rename to mlir/include/stablehlo/mhlo_legalize_to_standard_patterns.td diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt index e9043f25ed..f669d222c8 100644 --- a/mlir/lib/CMakeLists.txt +++ b/mlir/lib/CMakeLists.txt @@ -1,11 +1,11 @@ -add_subdirectory(Driver) add_subdirectory(CAPI) add_subdirectory(Catalyst) -add_subdirectory(Quantum) -add_subdirectory(QEC) +add_subdirectory(Driver) add_subdirectory(Gradient) add_subdirectory(Ion) add_subdirectory(MBQC) add_subdirectory(Mitigation) +add_subdirectory(QEC) +add_subdirectory(Quantum) +add_subdirectory(stablehlo) add_subdirectory(Test) -add_subdirectory(mlir-hlo) diff --git a/mlir/lib/mlir-hlo/CMakeLists.txt b/mlir/lib/stablehlo/CMakeLists.txt similarity index 100% rename from mlir/lib/mlir-hlo/CMakeLists.txt rename to mlir/lib/stablehlo/CMakeLists.txt diff --git a/mlir/lib/mlir-hlo/mhlo_legalize_control_flow.cpp b/mlir/lib/stablehlo/mhlo_legalize_control_flow.cpp similarity index 100% rename from mlir/lib/mlir-hlo/mhlo_legalize_control_flow.cpp rename to mlir/lib/stablehlo/mhlo_legalize_control_flow.cpp diff --git a/mlir/lib/mlir-hlo/mhlo_legalize_sort.cpp b/mlir/lib/stablehlo/mhlo_legalize_sort.cpp similarity index 100% rename from mlir/lib/mlir-hlo/mhlo_legalize_sort.cpp rename to mlir/lib/stablehlo/mhlo_legalize_sort.cpp diff --git a/mlir/lib/mlir-hlo/mhlo_legalize_to_std.cpp b/mlir/lib/stablehlo/mhlo_legalize_to_std.cpp similarity index 100% rename from mlir/lib/mlir-hlo/mhlo_legalize_to_std.cpp rename to mlir/lib/stablehlo/mhlo_legalize_to_std.cpp From c8ba9a0e4fff395a72aedafc9d01bb15e03be743 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Wed, 30 Jul 2025 17:40:24 -0400 Subject: [PATCH 19/63] more cleaning --- .../Catalyst/Transforms/CMakeLists.txt | 8 -- ...stablehlo_legalize_to_standard_patterns.td | 119 ------------------ 2 files changed, 127 deletions(-) delete mode 100644 mlir/include/Catalyst/Transforms/stablehlo_legalize_to_standard_patterns.td diff --git a/mlir/include/Catalyst/Transforms/CMakeLists.txt b/mlir/include/Catalyst/Transforms/CMakeLists.txt index b4977026e0..52802b77fc 100644 --- a/mlir/include/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/include/Catalyst/Transforms/CMakeLists.txt @@ -2,11 +2,3 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name catalyst) add_public_tablegen_target(MLIRCatalystPassIncGen) add_mlir_doc(Passes CatalystPasses ./ -gen-pass-doc) - -# The following is taken from tensorflow/mlir-hlo repo -# to build the --legalize-to-std pass -set(LLVM_TARGET_DEFINITIONS stablehlo_legalize_to_standard_patterns.td) -include_directories( - ${CATALYST_MAIN_INCLUDE_DIR}/../stablehlo) -mlir_tablegen(generated_stablehlo_legalize_to_standard.cpp.inc -gen-rewriters) -add_public_tablegen_target(MLIRStablehloLegalizeToStandardIncGen) diff --git a/mlir/include/Catalyst/Transforms/stablehlo_legalize_to_standard_patterns.td b/mlir/include/Catalyst/Transforms/stablehlo_legalize_to_standard_patterns.td deleted file mode 100644 index a26ecde9cb..0000000000 --- a/mlir/include/Catalyst/Transforms/stablehlo_legalize_to_standard_patterns.td +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2025 Xanadu Quantum Technologies Inc. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - - -// This file is taken from the -// tensorflow/mlir-hlo -// repository, under the Apache 2.0 License, at -// https://github.com/tensorflow/mlir-hlo/blob/a5529d99fc4d1132b0c282a053d26c11e6636b3a/mhlo/transforms/legalize_to_standard/legalize_to_standard_patterns.td -// with the following copyright notice: - - /* Copyright 2019 The OpenXLA Authors. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ - - - -// This is the legalization pattern definition file for MHLO to StandardOps. - -include "mlir/IR/OpBase.td" -include "mlir/Dialect/Arith/IR/ArithOps.td" -include "mlir/Dialect/Math/IR/MathOps.td" -include "mlir/Dialect/Func/IR/FuncOps.td" -include "stablehlo/dialect/StablehloOps.td" - -//===----------------------------------------------------------------------===// -// Nullary op patterns. -//===----------------------------------------------------------------------===// - -def : Pat<(StableHLO_ConstantOp ElementsAttr:$value), - (Arith_ConstantOp $value)>; - -//===----------------------------------------------------------------------===// -// Binary op patterns. -//===----------------------------------------------------------------------===// - -def IsSameSizePred : CPred< - "cast($0.getType()).getShape() " - "== cast($1.getType()).getShape()">; -def IsSameSizeConstraint : Constraint; -def createFastMathNone : NativeCodeCall< - "::mlir::arith::FastMathFlagsAttr::get(" - "$_builder.getContext(), ::mlir::arith::FastMathFlags::none" - ")">; -def createOverflowNone : NativeCodeCall< - "::mlir::arith::IntegerOverflowFlagsAttr::get(" - "$_builder.getContext(), ::mlir::arith::IntegerOverflowFlags::none" - ")">; -def createDenormalIEEE : NativeCodeCall< - "::mlir::arith::DenormalModeAttr::get(" - "$_builder.getContext(), ::mlir::arith::DenormalMode::ieee" - ")">; - - -// Unary Lowering Patterns. -def : Pat<(StableHLO_CeilOp HLO_FpTensor:$i), (Math_CeilOp $i, (createFastMathNone ))>; - -// Binary Lowering Patterns. -def : Pat<(StableHLO_AndOp HLO_IntTensor:$l, HLO_IntTensor:$r), - (Arith_AndIOp $l, $r), - [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(StableHLO_OrOp HLO_IntTensor:$l, HLO_IntTensor:$r), - (Arith_OrIOp $l, $r), - [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(StableHLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r), - (Arith_AddFOp $l, $r, (createFastMathNone )), - [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(StableHLO_SubtractOp HLO_FpTensor:$l, HLO_FpTensor:$r), - (Arith_SubFOp $l, $r, (createFastMathNone )), - [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(StableHLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r), - (Arith_MulFOp $l, $r, (createFastMathNone )), - [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(StableHLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r), - (Arith_DivFOp $l, $r, (createFastMathNone )), - [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(StableHLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r), - (Arith_RemFOp $l, $r, (createFastMathNone )), - [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(StableHLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r), - (Arith_AddIOp $l, $r, (createOverflowNone )), - [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(StableHLO_SubtractOp HLO_IntTensor:$l, HLO_IntTensor:$r), - (Arith_SubIOp $l, $r, (createOverflowNone )), - [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(StableHLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r), - (Arith_MulIOp $l, $r, (createOverflowNone )), - [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(StableHLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r), - (Arith_DivSIOp $l, $r), - [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(StableHLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r), - (Arith_RemSIOp $l, $r), - [(IsSameSizeConstraint $l, $r)]>; - -def : Pat<(StableHLO_SelectOp $pred, $tv, $fv), - (SelectOp $pred, $tv, $fv), - [(IsSameSizeConstraint $pred, $tv), (IsSameSizeConstraint $tv, $fv)]>; From fc8afd26863f7dac07f4fb058896ea012456a3f8 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Wed, 30 Jul 2025 17:54:50 -0400 Subject: [PATCH 20/63] more breadcrumbs --- mlir/include/Catalyst/Transforms/Passes.h | 3 -- mlir/include/stablehlo/CMakeLists.txt | 12 +++--- mlir/include/stablehlo/Passes.h | 6 +-- mlir/include/stablehlo/Passes.td | 30 +++++++-------- ...tablehlo_legalize_to_standard_patterns.td} | 37 +++++++++---------- mlir/lib/CAPI/CMakeLists.txt | 2 +- .../Catalyst/Transforms/RegisterAllPasses.cpp | 6 +-- mlir/lib/Driver/Pipelines.cpp | 6 +-- mlir/lib/stablehlo/CMakeLists.txt | 12 +++--- ...pp => stablehlo_legalize_control_flow.cpp} | 0 ...e_sort.cpp => stablehlo_legalize_sort.cpp} | 0 ..._std.cpp => stablehlo_legalize_to_std.cpp} | 0 mlir/tools/catalyst-cli/CMakeLists.txt | 2 +- mlir/tools/quantum-opt/CMakeLists.txt | 2 +- 14 files changed, 56 insertions(+), 62 deletions(-) rename mlir/include/stablehlo/{mhlo_legalize_to_standard_patterns.td => stablehlo_legalize_to_standard_patterns.td} (77%) rename mlir/lib/stablehlo/{mhlo_legalize_control_flow.cpp => stablehlo_legalize_control_flow.cpp} (100%) rename mlir/lib/stablehlo/{mhlo_legalize_sort.cpp => stablehlo_legalize_sort.cpp} (100%) rename mlir/lib/stablehlo/{mhlo_legalize_to_std.cpp => stablehlo_legalize_to_std.cpp} (100%) diff --git a/mlir/include/Catalyst/Transforms/Passes.h b/mlir/include/Catalyst/Transforms/Passes.h index 82957a6201..f7872961ad 100644 --- a/mlir/include/Catalyst/Transforms/Passes.h +++ b/mlir/include/Catalyst/Transforms/Passes.h @@ -37,9 +37,6 @@ std::unique_ptr createQnodeToAsyncLoweringPass(); std::unique_ptr createRegisterInactiveCallbackPass(); std::unique_ptr createScatterLoweringPass(); std::unique_ptr createSplitMultipleTapesPass(); -std::unique_ptr createStablehloLegalizeSortPass(); -std::unique_ptr createStablehloLegalizeToStdPass(); -std::unique_ptr createStablehloLegalizeControlFlowPass(); void registerAllCatalystPasses(); diff --git a/mlir/include/stablehlo/CMakeLists.txt b/mlir/include/stablehlo/CMakeLists.txt index 584f1e2617..78bc778a6d 100644 --- a/mlir/include/stablehlo/CMakeLists.txt +++ b/mlir/include/stablehlo/CMakeLists.txt @@ -1,14 +1,14 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls -name mlir-hlo) -add_public_tablegen_target(MLIRHLOCatalystPassIncGen) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name stablehlo) +add_public_tablegen_target(STABLEHLOCatalystPassIncGen) # The following is modified from the # tensorflow/mlir-hlo # repository at # https://github.com/tensorflow/mlir-hlo/blob/a5529d99fc4d1132b0c282a053d26c11e6636b3a/mhlo/transforms/CMakeLists.txt -# to build the rewrite patterns for the --mhlo-legalize-to-std pass -set(LLVM_TARGET_DEFINITIONS mhlo_legalize_to_standard_patterns.td) +# to build the rewrite patterns for the --stablehlo-legalize-to-std pass +set(LLVM_TARGET_DEFINITIONS stablehlo_legalize_to_standard_patterns.td) include_directories( ${CATALYST_MAIN_INCLUDE_DIR}/../mlir-hlo) -mlir_tablegen(generated_mhlo_legalize_to_standard.cpp.inc -gen-rewriters) -add_public_tablegen_target(MLIRMhloLegalizeToStandardIncGen) +mlir_tablegen(generated_stablehlo_legalize_to_standard.cpp.inc -gen-rewriters) +add_public_tablegen_target(MLIRStablehloLegalizeToStandardIncGen) diff --git a/mlir/include/stablehlo/Passes.h b/mlir/include/stablehlo/Passes.h index 6ce80c494e..260c9f9ba1 100644 --- a/mlir/include/stablehlo/Passes.h +++ b/mlir/include/stablehlo/Passes.h @@ -19,7 +19,7 @@ #include "mlir/Pass/Pass.h" namespace catalyst { - std::unique_ptr createMhloLegalizeSortPass(); - std::unique_ptr createMhloLegalizeToStdPass(); - std::unique_ptr createMhloLegalizeControlFlowPass(); + std::unique_ptr createStablehloLegalizeSortPass(); + std::unique_ptr createStablehloLegalizeToStdPass(); + std::unique_ptr createStablehloLegalizeControlFlowPass(); } diff --git a/mlir/include/stablehlo/Passes.td b/mlir/include/stablehlo/Passes.td index b4b7791018..64fd653ed1 100644 --- a/mlir/include/stablehlo/Passes.td +++ b/mlir/include/stablehlo/Passes.td @@ -31,31 +31,31 @@ // limitations under the License. // ==============================================================================*/ -#ifndef CATALYST_MLIRHLO_PASSES -#define CATALYST_MLIRHLO_PASSES +#ifndef CATALYST_STABLEHLO_PASSES +#define CATALYST_STABLEHLO_PASSES include "mlir/Pass/PassBase.td" -// mhlo legalize sort pass. -def MhloLegalizeSortPass : Pass<"mhlo-legalize-sort", "func::FuncOp"> { - let summary = "Legalize from Mhlo sort to SCF control flow."; - let constructor = "createMhloLegalizeSortPass()"; +// stablehlo legalize sort pass. +def StablehloLegalizeSortPass : Pass<"stablehlo-legalize-sort", "func::FuncOp"> { + let summary = "Legalize from Stablehlo sort to SCF control flow."; + let constructor = "createStablehloLegalizeSortPass()"; let dependentDialects = ["arith::ArithDialect", "bufferization::BufferizationDialect", "scf::SCFDialect", "tensor::TensorDialect"]; } -// mhlo legalize to std pass. -def MhloLegalizeToStandardPass : Pass<"mhlo-legalize-to-std", "func::FuncOp"> { - let summary = "Legalize from MHLO dialect to standard dialect."; - let constructor = "createMhloLegalizeToStdPass()"; +// stablehlo legalize to std pass. +def StablehloLegalizeToStandardPass : Pass<"stablehlo-legalize-to-std", "func::FuncOp"> { + let summary = "Legalize from Stablehlo dialect to standard dialect."; + let constructor = "createStablehloLegalizeToStdPass()"; } -// mhlo legalize to control flow pass. -def MhloLegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "func::FuncOp"> { - let summary = "Legalize from MHLO control flow to SCF control flow."; - let constructor = "createMhloLegalizeControlFlowPass()"; +// stablehlo legalize to control flow pass. +def StablehloLegalizeControlFlowPass : Pass<"stablehlo-legalize-control-flow", "func::FuncOp"> { + let summary = "Legalize from Stablehlo control flow to SCF control flow."; + let constructor = "createStablehloLegalizeControlFlowPass()"; let dependentDialects = ["scf::SCFDialect", "tensor::TensorDialect"]; } -#endif // CATALYST_MLIRHLO_PASSES +#endif // CATALYST_STABLEHLO_PASSES diff --git a/mlir/include/stablehlo/mhlo_legalize_to_standard_patterns.td b/mlir/include/stablehlo/stablehlo_legalize_to_standard_patterns.td similarity index 77% rename from mlir/include/stablehlo/mhlo_legalize_to_standard_patterns.td rename to mlir/include/stablehlo/stablehlo_legalize_to_standard_patterns.td index a8b365bc18..ffeddb9e93 100644 --- a/mlir/include/stablehlo/mhlo_legalize_to_standard_patterns.td +++ b/mlir/include/stablehlo/stablehlo_legalize_to_standard_patterns.td @@ -36,22 +36,19 @@ -// This is the legalization pattern definition file for MHLO to StandardOps. +// This is the legalization pattern definition file for Stablehlo to StandardOps. include "mlir/IR/OpBase.td" include "mlir/Dialect/Arith/IR/ArithOps.td" include "mlir/Dialect/Math/IR/MathOps.td" include "mlir/Dialect/Func/IR/FuncOps.td" -include "mhlo/IR/hlo_ops.td" -// TODO: change the above mhlo include line to the following when migrating to stablehlo -//include "stablehlo/dialect/StablehloOps.td" +include "stablehlo/dialect/StablehloOps.td" //===----------------------------------------------------------------------===// // Nullary op patterns. //===----------------------------------------------------------------------===// -// TODO: update `MHLO_BlahOp` to `StableHLO_BlahOp` when migrating to stablehlo. -def : Pat<(MHLO_ConstantOp ElementsAttr:$value), +def : Pat<(StableHLO_ConstantOp ElementsAttr:$value), (Arith_ConstantOp $value)>; //===----------------------------------------------------------------------===// @@ -77,46 +74,46 @@ def createDenormalIEEE : NativeCodeCall< // Unary Lowering Patterns. -def : Pat<(MHLO_CeilOp HLO_FpTensor:$i), (Math_CeilOp $i, (createFastMathNone ))>; +def : Pat<(StableHLO_CeilOp HLO_FpTensor:$i), (Math_CeilOp $i, (createFastMathNone ))>; // Binary Lowering Patterns. -def : Pat<(MHLO_AndOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(StableHLO_AndOp HLO_IntTensor:$l, HLO_IntTensor:$r), (Arith_AndIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_OrOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(StableHLO_OrOp HLO_IntTensor:$l, HLO_IntTensor:$r), (Arith_OrIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r), +def : Pat<(StableHLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r), (Arith_AddFOp $l, $r, (createFastMathNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_SubtractOp HLO_FpTensor:$l, HLO_FpTensor:$r), +def : Pat<(StableHLO_SubtractOp HLO_FpTensor:$l, HLO_FpTensor:$r), (Arith_SubFOp $l, $r, (createFastMathNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r), +def : Pat<(StableHLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r), (Arith_MulFOp $l, $r, (createFastMathNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r), +def : Pat<(StableHLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r), (Arith_DivFOp $l, $r, (createFastMathNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r), +def : Pat<(StableHLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r), (Arith_RemFOp $l, $r, (createFastMathNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(StableHLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r), (Arith_AddIOp $l, $r, (createOverflowNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_SubtractOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(StableHLO_SubtractOp HLO_IntTensor:$l, HLO_IntTensor:$r), (Arith_SubIOp $l, $r, (createOverflowNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(StableHLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r), (Arith_MulIOp $l, $r, (createOverflowNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(StableHLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r), (Arith_DivSIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(StableHLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r), (Arith_RemSIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(MHLO_SelectOp $pred, $tv, $fv), +def : Pat<(StableHLO_SelectOp $pred, $tv, $fv), (SelectOp $pred, $tv, $fv), [(IsSameSizeConstraint $pred, $tv), (IsSameSizeConstraint $tv, $fv)]>; diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt index 949f171567..d6afa814aa 100644 --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -4,7 +4,7 @@ add_mlir_public_c_api_library(QuantumCAPI LINK_LIBS PRIVATE MLIRCatalyst catalyst-transforms - catalyst-mhlo-transforms + catalyst-stablehlo-transforms MLIRQuantum quantum-transforms MLIRQEC diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index 58ee131f45..8c9609e875 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -59,9 +59,6 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createMBQCConversionPass); mlir::registerPass(catalyst::createMemrefCopyToLinalgCopyPass); mlir::registerPass(catalyst::createMemrefToLLVMWithTBAAPass); - mlir::registerPass(catalyst::createMhloLegalizeSortPass); - mlir::registerPass(catalyst::createMhloLegalizeToStdPass); - mlir::registerPass(catalyst::createMhloLegalizeControlFlowPass); mlir::registerPass(catalyst::createMitigationLoweringPass); mlir::registerPass(catalyst::createQnodeToAsyncLoweringPass); mlir::registerPass(catalyst::createQuantumConversionPass); @@ -69,6 +66,9 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createRemoveChainedSelfInversePass); mlir::registerPass(catalyst::createMergeRotationsPass); mlir::registerPass(catalyst::createScatterLoweringPass); + mlir::registerPass(catalyst::createStablehloLegalizeControlFlowPass); + mlir::registerPass(catalyst::createStablehloLegalizeSortPass); + mlir::registerPass(catalyst::createStablehloLegalizeToStdPass); mlir::registerPass(catalyst::createSplitMultipleTapesPass); mlir::registerPass(catalyst::createTestPass); } diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index eb01ba0a20..aaeb6c225d 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -44,10 +44,10 @@ void createHloLoweringPipeline(OpPassManager &pm) pm.addNestedPass(mhlo::createChloLegalizeToHloPass()); pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); - pm.addNestedPass(catalyst::createMhloLegalizeControlFlowPass()); + pm.addNestedPass(catalyst::createStablehloLegalizeControlFlowPass()); pm.addNestedPass(mhlo::createLegalizeHloToLinalgPass()); - pm.addNestedPass(catalyst::createMhloLegalizeToStdPass()); - pm.addNestedPass(catalyst::createMhloLegalizeSortPass()); + pm.addNestedPass(catalyst::createStablehloLegalizeToStdPass()); + pm.addNestedPass(catalyst::createStablehloLegalizeSortPass()); pm.addPass(mlir::mhlo::createConvertToSignlessPass()); pm.addPass(mlir::createCanonicalizerPass()); diff --git a/mlir/lib/stablehlo/CMakeLists.txt b/mlir/lib/stablehlo/CMakeLists.txt index 8ad4f319ba..8018f749ed 100644 --- a/mlir/lib/stablehlo/CMakeLists.txt +++ b/mlir/lib/stablehlo/CMakeLists.txt @@ -1,9 +1,9 @@ -set(LIBRARY_NAME catalyst-mhlo-transforms) +set(LIBRARY_NAME catalyst-stablehlo-transforms) file(GLOB SRC - mhlo_legalize_control_flow.cpp - mhlo_legalize_sort.cpp - mhlo_legalize_to_std.cpp + stablehlo_legalize_control_flow.cpp + stablehlo_legalize_sort.cpp + stablehlo_legalize_to_std.cpp ) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) @@ -15,8 +15,8 @@ set(LIBS set(DEPENDS MLIRCatalystPassIncGen - MLIRHLOCatalystPassIncGen - MLIRMhloLegalizeToStandardIncGen + STABLEHLOCatalystPassIncGen + MLIRStablehloLegalizeToStandardIncGen ) add_mlir_library(${LIBRARY_NAME} STATIC ${SRC} LINK_LIBS PRIVATE ${LIBS} DEPENDS ${DEPENDS}) diff --git a/mlir/lib/stablehlo/mhlo_legalize_control_flow.cpp b/mlir/lib/stablehlo/stablehlo_legalize_control_flow.cpp similarity index 100% rename from mlir/lib/stablehlo/mhlo_legalize_control_flow.cpp rename to mlir/lib/stablehlo/stablehlo_legalize_control_flow.cpp diff --git a/mlir/lib/stablehlo/mhlo_legalize_sort.cpp b/mlir/lib/stablehlo/stablehlo_legalize_sort.cpp similarity index 100% rename from mlir/lib/stablehlo/mhlo_legalize_sort.cpp rename to mlir/lib/stablehlo/stablehlo_legalize_sort.cpp diff --git a/mlir/lib/stablehlo/mhlo_legalize_to_std.cpp b/mlir/lib/stablehlo/stablehlo_legalize_to_std.cpp similarity index 100% rename from mlir/lib/stablehlo/mhlo_legalize_to_std.cpp rename to mlir/lib/stablehlo/stablehlo_legalize_to_std.cpp diff --git a/mlir/tools/catalyst-cli/CMakeLists.txt b/mlir/tools/catalyst-cli/CMakeLists.txt index 2b1d16a5ef..46d667cba1 100644 --- a/mlir/tools/catalyst-cli/CMakeLists.txt +++ b/mlir/tools/catalyst-cli/CMakeLists.txt @@ -26,7 +26,7 @@ set(LIBS MLIROptLib MLIRCatalyst catalyst-transforms - catalyst-mhlo-transforms + catalyst-stablehlo-transforms MLIRQuantum quantum-transforms MLIRQEC diff --git a/mlir/tools/quantum-opt/CMakeLists.txt b/mlir/tools/quantum-opt/CMakeLists.txt index bdcdf9a98e..00b7dc7516 100644 --- a/mlir/tools/quantum-opt/CMakeLists.txt +++ b/mlir/tools/quantum-opt/CMakeLists.txt @@ -8,7 +8,7 @@ set(LIBS MLIROptLib MLIRCatalyst catalyst-transforms - catalyst-mhlo-transforms + catalyst-stablehlo-transforms MLIRQuantum quantum-transforms MLIRQEC From ae49138ed050da580ad0ae3f9ebd3e445a9eec02 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Wed, 30 Jul 2025 18:43:37 -0400 Subject: [PATCH 21/63] breadcrumbs --- mlir/CMakeLists.txt | 1 + mlir/include/stablehlo/CMakeLists.txt | 2 +- mlir/lib/Catalyst/Transforms/CMakeLists.txt | 7 +- .../Catalyst/Transforms/RegisterAllPasses.cpp | 5 +- mlir/lib/stablehlo/CMakeLists.txt | 7 +- .../stablehlo_legalize_control_flow.cpp | 91 +++++++++---------- .../lib/stablehlo/stablehlo_legalize_sort.cpp | 32 +++---- .../stablehlo/stablehlo_legalize_to_std.cpp | 51 +++++------ 8 files changed, 97 insertions(+), 99 deletions(-) diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index c47db99ccc..49b65f90d4 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -214,6 +214,7 @@ add_subdirectory(cmake/modules) ###################### add_subdirectory(test) + unset(LLVM_USE_LINKER) set(STABLEHLO_BUILD_EMBEDDED ON) set(CMAKE_COMPILE_WARNING_AS_ERROR OFF) diff --git a/mlir/include/stablehlo/CMakeLists.txt b/mlir/include/stablehlo/CMakeLists.txt index 78bc778a6d..d643202931 100644 --- a/mlir/include/stablehlo/CMakeLists.txt +++ b/mlir/include/stablehlo/CMakeLists.txt @@ -9,6 +9,6 @@ add_public_tablegen_target(STABLEHLOCatalystPassIncGen) # to build the rewrite patterns for the --stablehlo-legalize-to-std pass set(LLVM_TARGET_DEFINITIONS stablehlo_legalize_to_standard_patterns.td) include_directories( - ${CATALYST_MAIN_INCLUDE_DIR}/../mlir-hlo) + ${CATALYST_MAIN_INCLUDE_DIR}/../stablehlo) mlir_tablegen(generated_stablehlo_legalize_to_standard.cpp.inc -gen-rewriters) add_public_tablegen_target(MLIRStablehloLegalizeToStandardIncGen) diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index b4776af6a9..0811181565 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -34,10 +34,14 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) set(LIBS ${dialect_libs} ${conversion_libs} + StablehloPasses + StablehloOps ) set(DEPENDS MLIRCatalystPassIncGen + StablehloBaseIncGen + StablehloOpsIncGen ) add_mlir_library(${LIBRARY_NAME} STATIC ${SRC} LINK_LIBS PRIVATE ${LIBS} DEPENDS ${DEPENDS}) @@ -45,4 +49,5 @@ target_compile_features(${LIBRARY_NAME} PUBLIC cxx_std_20) target_include_directories(${LIBRARY_NAME} PUBLIC . ${PROJECT_SOURCE_DIR}/include - ${CMAKE_BINARY_DIR}/include) + ${CMAKE_BINARY_DIR}/include + ${CATALYST_MAIN_INCLUDE_DIR}/../stablehlo) diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index 8c9609e875..85bed5358e 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "mlir/Pass/PassRegistry.h" + #include "Catalyst/Transforms/Passes.h" #include "Gradient/Transforms/Passes.h" #include "Ion/Transforms/Passes.h" @@ -19,9 +21,8 @@ #include "Mitigation/Transforms/Passes.h" #include "QEC/Transforms/Passes.h" #include "Quantum/Transforms/Passes.h" +#include "stablehlo/Passes.h" #include "Test/Transforms/Passes.h" -#include "mlir-hlo/Passes.h" -#include "mlir/Pass/PassRegistry.h" void catalyst::registerAllCatalystPasses() { diff --git a/mlir/lib/stablehlo/CMakeLists.txt b/mlir/lib/stablehlo/CMakeLists.txt index 8018f749ed..160e9f4538 100644 --- a/mlir/lib/stablehlo/CMakeLists.txt +++ b/mlir/lib/stablehlo/CMakeLists.txt @@ -11,12 +11,16 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) set(LIBS ${dialect_libs} ${conversion_libs} + StablehloPasses + StablehloOps ) set(DEPENDS MLIRCatalystPassIncGen STABLEHLOCatalystPassIncGen MLIRStablehloLegalizeToStandardIncGen + StablehloBaseIncGen + StablehloOpsIncGen ) add_mlir_library(${LIBRARY_NAME} STATIC ${SRC} LINK_LIBS PRIVATE ${LIBS} DEPENDS ${DEPENDS}) @@ -24,4 +28,5 @@ target_compile_features(${LIBRARY_NAME} PUBLIC cxx_std_20) target_include_directories(${LIBRARY_NAME} PUBLIC . ${PROJECT_SOURCE_DIR}/include - ${CMAKE_BINARY_DIR}/include) + ${CMAKE_BINARY_DIR}/include + ${CATALYST_MAIN_INCLUDE_DIR}/../stablehlo) diff --git a/mlir/lib/stablehlo/stablehlo_legalize_control_flow.cpp b/mlir/lib/stablehlo/stablehlo_legalize_control_flow.cpp index fd9641cb81..fd28262c46 100644 --- a/mlir/lib/stablehlo/stablehlo_legalize_control_flow.cpp +++ b/mlir/lib/stablehlo/stablehlo_legalize_control_flow.cpp @@ -33,16 +33,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// The modifications are porting the pass from the upstream MHLO namespace to +// The modifications are porting the pass from the upstream stablehlo namespace to // catalyst namespace. -// This file implements logic for lowering MHLO dialect to SCF dialect. +// This file implements logic for lowering Stablehlo dialect to SCF dialect. #include #include #include -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/passes.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project @@ -56,38 +55,34 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" -// #include "stablehlo/dialect/StablehloOps.h" -// #include "stablehlo/transforms/Passes.h" -#include "llvm/Support/Casting.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" -#include "mlir-hlo/Passes.h" +#include "stablehlo/Passes.h" using namespace mlir; -using namespace mhlo; -// using namespace stablehlo; +using namespace stablehlo; using namespace catalyst; namespace catalyst { -#define GEN_PASS_DEF_MHLOLEGALIZECONTROLFLOWPASS -#define GEN_PASS_DECL_MHLOLEGALIZECONTROLFLOWPASS -// #define GEN_PASS_DEF_STABLEHLOLEGALIZECONTROLFLOWPASS -// #define GEN_PASS_DECL_STABLEHLOLEGALIZECONTROLFLOWPASS -#include "mlir-hlo/Passes.h.inc" +#define GEN_PASS_DEF_STABLEHLOLEGALIZECONTROLFLOWPASS +#define GEN_PASS_DECL_STABLEHLOLEGALIZECONTROLFLOWPASS +#include "stablehlo/Passes.h.inc" } // namespace catalyst namespace { -// All transformations in this file take mhlo blocks which end with +// All transformations in this file take stablehlo blocks which end with // stablehlo::ReturnOp and lower to SCF ops which end with scf::YieldOp. Inline an // entire block with the only change being return -> yield. -void inlineMhloRegionIntoSCFRegion(PatternRewriter &rewriter, Region &mhlo, Region &scf) +void inlineStablehloRegionIntoSCFRegion(PatternRewriter &rewriter, Region &r, Region &scf) { // Remove an existing block, then move the region over. if (!scf.empty()) rewriter.eraseBlock(&scf.back()); - rewriter.inlineRegionBefore(mhlo, scf, scf.end()); + rewriter.inlineRegionBefore(r, scf, scf.end()); // Fix up the terminator. PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToEnd(&scf.back()); @@ -95,7 +90,7 @@ void inlineMhloRegionIntoSCFRegion(PatternRewriter &rewriter, Region &mhlo, Regi rewriter.replaceOpWithNewOp(terminator, terminator->getOperands()); } -// mhlo ops need inputs to be tensors, but scalar values can be a scalar tensor +// stablehlo ops need inputs to be tensors, but scalar values can be a scalar tensor // or a 1 element tensor. To handle this, collapse shape before extracting the // scalar value when necessary. Value extractTensorValue(OpBuilder &b, Value tensor) @@ -116,7 +111,7 @@ struct ScfForBounds { unsigned indexArgIndex; }; -std::optional extractForBounds(mhlo::WhileOp op) +std::optional extractForBounds(stablehlo::WhileOp op) { auto &cond = op.getCond().front(); auto &body = op.getBody().front(); @@ -129,10 +124,10 @@ std::optional extractForBounds(mhlo::WhileOp op) return mlir::cast(v).getArgNumber(); }; - auto compare = llvm::dyn_cast(cond.front()); + auto compare = llvm::dyn_cast(cond.front()); // If the rhs of the comapare is defined outside the block, it's a constant // within the loop. - if (!compare || compare.getComparisonDirection() != mhlo::ComparisonDirection::LT || + if (!compare || compare.getComparisonDirection() != stablehlo::ComparisonDirection::LT || compare.getRhs().getParentBlock() == &cond || !getElementTypeOrSelf(compare.getLhs().getType()).isSignlessIntOrIndex()) { return std::nullopt; @@ -142,7 +137,7 @@ std::optional extractForBounds(mhlo::WhileOp op) if (!iterArg) return std::nullopt; - auto add = llvm::dyn_cast_or_null( + auto add = llvm::dyn_cast_or_null( body.getTerminator()->getOperand(*iterArg).getDefiningOp()); if (!add || matchBbArg(add.getLhs(), body) != iterArg || add.getRhs().getParentBlock() == &body) { @@ -158,10 +153,10 @@ std::optional extractForBounds(mhlo::WhileOp op) } // Rewrites `stablehlo.while` to `scf.while` or `scf.for`. -struct WhileOpPattern : public OpConversionPattern { +struct WhileOpPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(mhlo::WhileOp op, OpAdaptor adaptor, + LogicalResult matchAndRewrite(stablehlo::WhileOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); @@ -173,8 +168,8 @@ struct WhileOpPattern : public OpConversionPattern { extractTensorValue(rewriter, bounds->step), adaptor.getOperands()); rewriter.setInsertionPointToEnd(newForOp.getBody()); - // Inline while body, and only replace the mhlo.return with an scf.yield. - inlineMhloRegionIntoSCFRegion(rewriter, op.getBody(), newForOp.getRegion()); + // Inline while body, and only replace the stablehlo.return with an scf.yield. + inlineStablehloRegionIntoSCFRegion(rewriter, op.getBody(), newForOp.getRegion()); auto indexArg = newForOp.getRegion().insertArgument( unsigned{0}, newForOp.getLowerBound().getType(), loc); auto oldIndexArg = newForOp.getRegion().getArgument(1 + bounds->indexArgIndex); @@ -194,39 +189,39 @@ struct WhileOpPattern : public OpConversionPattern { // needs to be extracted and used with an scf.condition. rewriter.inlineRegionBefore(op.getCond(), newWhileOp.getBefore(), newWhileOp.getBefore().end()); - auto conditionReturn = cast(newWhileOp.getBefore().front().getTerminator()); + auto conditionReturn = cast(newWhileOp.getBefore().front().getTerminator()); rewriter.setInsertionPointToEnd(&newWhileOp.getBefore().front()); Value i1 = extractTensorValue(rewriter, conditionReturn->getOperand(0)); rewriter.replaceOpWithNewOp(conditionReturn, i1, newWhileOp.getBeforeArguments()); - // Inline while body, and only replace the mhlo.return with an scf.yield. - inlineMhloRegionIntoSCFRegion(rewriter, op.getBody(), newWhileOp.getAfter()); + // Inline while body, and only replace the stablehlo.return with an scf.yield. + inlineStablehloRegionIntoSCFRegion(rewriter, op.getBody(), newWhileOp.getAfter()); rewriter.replaceOp(op, newWhileOp.getResults()); return success(); } }; -// Rewrites `mhlo.if` to `scf.if`. -struct IfOpPattern : public OpConversionPattern { +// Rewrites `stablehlo.if` to `scf.if`. +struct IfOpPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(mhlo::IfOp op, OpAdaptor adaptor, + LogicalResult matchAndRewrite(stablehlo::IfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto scfIf = rewriter.create(op.getLoc(), op.getResultTypes(), extractTensorValue(rewriter, adaptor.getPred()), /*withElseRegion=*/true); - inlineMhloRegionIntoSCFRegion(rewriter, op.getTrueBranch(), scfIf.getThenRegion()); - inlineMhloRegionIntoSCFRegion(rewriter, op.getFalseBranch(), scfIf.getElseRegion()); + inlineStablehloRegionIntoSCFRegion(rewriter, op.getTrueBranch(), scfIf.getThenRegion()); + inlineStablehloRegionIntoSCFRegion(rewriter, op.getFalseBranch(), scfIf.getElseRegion()); rewriter.replaceOp(op, scfIf.getResults()); return success(); } }; -// Rewrites `mhlo.case` to a nested `scf.if`. -struct CaseOpPattern : public OpConversionPattern { +// Rewrites `stablehlo.case` to a nested `scf.if`. +struct CaseOpPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; // Recursively create if/else ops to handle each possible value in a case op. @@ -243,20 +238,20 @@ struct CaseOpPattern : public OpConversionPattern { auto constAttr = DenseElementsAttr::get( shapedType, {mlir::cast(outerBuilder.getI32IntegerAttr(currentIdx))}); Value currentIdxVal = - outerBuilder.create(loc, idxValue.getType(), constAttr); + outerBuilder.create(loc, idxValue.getType(), constAttr); auto scfIf = outerBuilder.create( loc, op.getResultTypes(), extractTensorValue(outerBuilder, - outerBuilder.create(loc, idxValue, currentIdxVal, + outerBuilder.create(loc, idxValue, currentIdxVal, ComparisonDirection::EQ)), /*withElseRegion=*/true); - inlineMhloRegionIntoSCFRegion(outerBuilder, op.getBranches()[currentIdx], + inlineStablehloRegionIntoSCFRegion(outerBuilder, op.getBranches()[currentIdx], scfIf.getThenRegion()); int nextIdx = currentIdx + 1; // Don't recurse for the final default block. if (currentIdx == static_cast(finalIdx)) { - inlineMhloRegionIntoSCFRegion(outerBuilder, op.getBranches()[nextIdx], + inlineStablehloRegionIntoSCFRegion(outerBuilder, op.getBranches()[nextIdx], scfIf.getElseRegion()); } else { @@ -268,14 +263,14 @@ struct CaseOpPattern : public OpConversionPattern { return scfIf; } - LogicalResult matchAndRewrite(mhlo::CaseOp op, OpAdaptor adaptor, + LogicalResult matchAndRewrite(stablehlo::CaseOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Inline the op if there is only a default block. if (op.getBranches().size() == 1) { Block &block = op.getBranches().front().front(); auto results = block.getTerminator()->getOperands(); - // Remove the mhlo.return terminator, then inline the block. + // Remove the stablehlo.return terminator, then inline the block. rewriter.eraseOp(block.getTerminator()); rewriter.inlineBlockBefore(/*source=*/&block, /*dest=*/op.getOperation(), /*argValues=*/{}); @@ -289,8 +284,8 @@ struct CaseOpPattern : public OpConversionPattern { } }; -struct MhloLegalizeControlFlowPass - : public catalyst::impl::MhloLegalizeControlFlowPassBase { +struct StablehloLegalizeControlFlowPass + : public catalyst::impl::StablehloLegalizeControlFlowPassBase { // Perform the lowering to MLIR control flow. void runOnOperation() override { @@ -302,7 +297,7 @@ struct MhloLegalizeControlFlowPass mlir::ConversionTarget target(*ctx); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - target.addIllegalOp(); + target.addIllegalOp(); if (failed(applyPartialConversion(f, target, std::move(patterns)))) { signalPassFailure(); @@ -312,7 +307,7 @@ struct MhloLegalizeControlFlowPass } // namespace -std::unique_ptr catalyst::createMhloLegalizeControlFlowPass() +std::unique_ptr catalyst::createStablehloLegalizeControlFlowPass() { - return std::make_unique(); + return std::make_unique(); } diff --git a/mlir/lib/stablehlo/stablehlo_legalize_sort.cpp b/mlir/lib/stablehlo/stablehlo_legalize_sort.cpp index 98a87b1726..4418c9d3cf 100644 --- a/mlir/lib/stablehlo/stablehlo_legalize_sort.cpp +++ b/mlir/lib/stablehlo/stablehlo_legalize_sort.cpp @@ -33,7 +33,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// The modifications are porting the pass from the upstream MHLO namespace to +// The modifications are porting the pass from the upstream stablehlo namespace to // catalyst namespace. // This file implements logic for lowering stablehlo.sort to the SCF dialect. @@ -41,8 +41,6 @@ limitations under the License. #include #include -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -64,23 +62,21 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/STLExtras.h" -// #include "stablehlo/dialect/StablehloOps.h" -// #include "stablehlo/transforms/Passes.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" -#include "mlir-hlo/Passes.h" +#include "stablehlo/Passes.h" using namespace mlir; -using namespace mhlo; -// using namespace stablehlo; +using namespace stablehlo; using namespace catalyst; namespace catalyst { -#define GEN_PASS_DEF_MHLOLEGALIZESORTPASS -#define GEN_PASS_DECL_MHLOLEGALIZESORTPASS -// #define GEN_PASS_DEF_STABLEHLOLEGALIZESORTPASS -// #define GEN_PASS_DECL_STABLEHLOLEGALIZESORTPASS -#include "mlir-hlo/Passes.h.inc" + +#define GEN_PASS_DEF_STABLEHLOLEGALIZESORTPASS +#define GEN_PASS_DECL_STABLEHLOLEGALIZESORTPASS +#include "stablehlo/Passes.h.inc" } // namespace catalyst @@ -574,8 +570,8 @@ struct SortOpPattern : public OpRewritePattern { } }; -struct MhloLegalizeSortPass - : public catalyst::impl::MhloLegalizeSortPassBase { +struct StablehloLegalizeSortPass + : public catalyst::impl::StablehloLegalizeSortPassBase { // Perform the lowering to MLIR control flow. void runOnOperation() override { @@ -587,7 +583,7 @@ struct MhloLegalizeSortPass mlir::ConversionTarget target(*ctx); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - target.addIllegalOp(); + target.addIllegalOp(); if (failed(applyPartialConversion(f, target, std::move(patterns)))) { signalPassFailure(); @@ -597,7 +593,7 @@ struct MhloLegalizeSortPass } // namespace -std::unique_ptr catalyst::createMhloLegalizeSortPass() +std::unique_ptr catalyst::createStablehloLegalizeSortPass() { - return std::make_unique(); + return std::make_unique(); } diff --git a/mlir/lib/stablehlo/stablehlo_legalize_to_std.cpp b/mlir/lib/stablehlo/stablehlo_legalize_to_std.cpp index 020f39b297..d987e9e2ee 100644 --- a/mlir/lib/stablehlo/stablehlo_legalize_to_std.cpp +++ b/mlir/lib/stablehlo/stablehlo_legalize_to_std.cpp @@ -33,18 +33,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// The modifications are porting the pass from the upstream MHLO namespace to +// The modifications are porting the pass from the upstream stablehlo namespace to // catalyst namespace. -// This file implements logic for lowering MHLO dialect to Standard dialect. +// This file implements logic for lowering Stablehlo dialect to Standard dialect. #include #include #include -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/passes.h" -#include "mhlo/transforms/rewriters.h" // (??) +//#include "mhlo/transforms/rewriters.h" // (??) #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -52,34 +50,31 @@ limitations under the License. #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -// #include "stablehlo/dialect/StablehloOps.h" -// #include "stablehlo/transforms/Passes.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" -#include "mlir-hlo/Passes.h" +#include "stablehlo/Passes.h" using namespace mlir; -using namespace mhlo; -// using namespace stablehlo; +using namespace stablehlo; using namespace catalyst; namespace catalyst { -#define GEN_PASS_DEF_MHLOLEGALIZETOSTANDARDPASS -#define GEN_PASS_DECL_MHLOLEGALIZETOSTANDARDPASS -// #define GEN_PASS_DEF_STABLEHLOLEGALIZETOSTANDARDPASS -// #define GEN_PASS_DECL_STABLEHLOLEGALIZETOSTANDARDPASS +#define GEN_PASS_DEF_STABLEHLOLEGALIZETOSTANDARDPASS +#define GEN_PASS_DECL_STABLEHLOLEGALIZETOSTANDARDPASS #include "mlir-hlo/Passes.h.inc" -#include "mlir-hlo/generated_mhlo_legalize_to_standard.cpp.inc" +#include "mlir-hlo/generated_stablehlo_legalize_to_standard.cpp.inc" } // namespace catalyst namespace { -class CompareIConvert : public OpRewritePattern { +class CompareIConvert : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mhlo::CompareOp op, PatternRewriter &rewriter) const override + LogicalResult matchAndRewrite(stablehlo::CompareOp op, PatternRewriter &rewriter) const override { auto lhs = op.getLhs(); auto rhs = op.getRhs(); @@ -124,11 +119,11 @@ class CompareIConvert : public OpRewritePattern { } }; -class CompareFConvert : public OpRewritePattern { +class CompareFConvert : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mhlo::CompareOp op, PatternRewriter &rewriter) const override + LogicalResult matchAndRewrite(stablehlo::CompareOp op, PatternRewriter &rewriter) const override { auto lhs = op.getLhs(); auto rhs = op.getRhs(); @@ -177,11 +172,11 @@ class CompareFConvert : public OpRewritePattern { // convert the integer constant to iota result type. For complex types, the real // part is replaced with the generated constant and the imaginary part is // replaced with zero tensor. -class ConvertIotaOp : public OpRewritePattern { +class ConvertIotaOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mhlo::IotaOp op, PatternRewriter &rewriter) const override + LogicalResult matchAndRewrite(stablehlo::IotaOp op, PatternRewriter &rewriter) const override { auto outputType = mlir::cast(op.getType()); auto outputSize = outputType.getNumElements(); @@ -233,19 +228,19 @@ class ConvertIotaOp : public OpRewritePattern { auto zeroes = rewriter.create( loc, DenseIntElementsAttr::get(intShapeType, APInt(bitwidth, 0))); auto imagZeroes = rewriter.create(loc, intOrFloatShapeTy, zeroes); - rewriter.replaceOpWithNewOp(op, iotaConst, imagZeroes); + rewriter.replaceOpWithNewOp(op, iotaConst, imagZeroes); return success(); } }; -void populateMhloToStdPatterns(RewritePatternSet *patterns, mlir::MLIRContext *ctx) +void populateStablehloToStdPatterns(RewritePatternSet *patterns, mlir::MLIRContext *ctx) { populateWithGenerated(*patterns); patterns->add(ctx); } -struct MhloLegalizeToStandardPass - : public catalyst::impl::MhloLegalizeToStandardPassBase { +struct StablehloLegalizeToStandardPass + : public catalyst::impl::StablehloLegalizeToStandardPassBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -255,14 +250,14 @@ struct MhloLegalizeToStandardPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); - populateMhloToStdPatterns(&patterns, &getContext()); + populateStablehloToStdPatterns(&patterns, &getContext()); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; } // end anonymous namespace -std::unique_ptr catalyst::createMhloLegalizeToStdPass() +std::unique_ptr catalyst::createStablehloLegalizeToStdPass() { - return std::make_unique(); + return std::make_unique(); } From 44122fcced9ffee0ef8429c49ecd7bb8dca074de Mon Sep 17 00:00:00 2001 From: paul0403 Date: Wed, 30 Jul 2025 19:19:03 -0400 Subject: [PATCH 22/63] `make dialects` succeed --- mlir/CMakeLists.txt | 10 ++++---- mlir/lib/Catalyst/Transforms/CMakeLists.txt | 4 +--- mlir/lib/Driver/Pipelines.cpp | 23 ++++++++++--------- mlir/lib/stablehlo/CMakeLists.txt | 3 +-- .../stablehlo/stablehlo_legalize_to_std.cpp | 4 ++-- 5 files changed, 21 insertions(+), 23 deletions(-) diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index 49b65f90d4..69dc46763f 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -42,6 +42,11 @@ list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/modules") +unset(LLVM_USE_LINKER) +set(STABLEHLO_BUILD_EMBEDDED ON) +set(CMAKE_COMPILE_WARNING_AS_ERROR OFF) +add_subdirectory(stablehlo) + # Policy CMP0175 was introduced in CMake 4.31 and raises warnings in the upstream CMake modules. # Policy CMP0177 was introduced in CMake 4.31 and raises warnings in the upstream CMake modules. # TODO: Remove once they (and us) have updated their code to deal with it. @@ -214,8 +219,3 @@ add_subdirectory(cmake/modules) ###################### add_subdirectory(test) - -unset(LLVM_USE_LINKER) -set(STABLEHLO_BUILD_EMBEDDED ON) -set(CMAKE_COMPILE_WARNING_AS_ERROR OFF) -add_subdirectory(stablehlo) diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index 0811181565..fa34e7e665 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -43,11 +43,9 @@ set(DEPENDS StablehloBaseIncGen StablehloOpsIncGen ) - add_mlir_library(${LIBRARY_NAME} STATIC ${SRC} LINK_LIBS PRIVATE ${LIBS} DEPENDS ${DEPENDS}) target_compile_features(${LIBRARY_NAME} PUBLIC cxx_std_20) target_include_directories(${LIBRARY_NAME} PUBLIC . ${PROJECT_SOURCE_DIR}/include - ${CMAKE_BINARY_DIR}/include - ${CATALYST_MAIN_INCLUDE_DIR}/../stablehlo) + ${CMAKE_BINARY_DIR}/include) diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index aaeb6c225d..9c56c70263 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -12,6 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "stablehlo/transforms/Passes.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + #include "Driver/Pipelines.h" #include "Catalyst/IR/CatalystDialect.h" #include "Catalyst/Transforms/Passes.h" @@ -20,13 +27,7 @@ #include "Mitigation/Transforms/Passes.h" #include "Quantum/IR/QuantumDialect.h" #include "Quantum/Transforms/Passes.h" - -#include "stablehlo/transforms/Passes.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/InitAllDialects.h" -#include "mlir/InitAllPasses.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/Passes.h" +#include "stablehlo/Passes.h" using namespace mlir; namespace catalyst { @@ -42,13 +43,13 @@ void createHloLoweringPipeline(OpPassManager &pm) { pm.addPass(mlir::createCanonicalizerPass()); - pm.addNestedPass(mhlo::createChloLegalizeToHloPass()); - pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); + pm.addNestedPass(stablehlo::createChloLegalizeToStablehloPass()); + //pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); pm.addNestedPass(catalyst::createStablehloLegalizeControlFlowPass()); - pm.addNestedPass(mhlo::createLegalizeHloToLinalgPass()); + //(how do I call this??)pm.addNestedPass(stablehlo::createStablehloLegalizeToLinalgPass()); pm.addNestedPass(catalyst::createStablehloLegalizeToStdPass()); pm.addNestedPass(catalyst::createStablehloLegalizeSortPass()); - pm.addPass(mlir::mhlo::createConvertToSignlessPass()); + pm.addPass(stablehlo::createStablehloConvertToSignlessPass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(catalyst::createScatterLoweringPass()); diff --git a/mlir/lib/stablehlo/CMakeLists.txt b/mlir/lib/stablehlo/CMakeLists.txt index 160e9f4538..0a256cfa74 100644 --- a/mlir/lib/stablehlo/CMakeLists.txt +++ b/mlir/lib/stablehlo/CMakeLists.txt @@ -28,5 +28,4 @@ target_compile_features(${LIBRARY_NAME} PUBLIC cxx_std_20) target_include_directories(${LIBRARY_NAME} PUBLIC . ${PROJECT_SOURCE_DIR}/include - ${CMAKE_BINARY_DIR}/include - ${CATALYST_MAIN_INCLUDE_DIR}/../stablehlo) + ${CMAKE_BINARY_DIR}/include) diff --git a/mlir/lib/stablehlo/stablehlo_legalize_to_std.cpp b/mlir/lib/stablehlo/stablehlo_legalize_to_std.cpp index d987e9e2ee..9994fd7aa8 100644 --- a/mlir/lib/stablehlo/stablehlo_legalize_to_std.cpp +++ b/mlir/lib/stablehlo/stablehlo_legalize_to_std.cpp @@ -63,8 +63,8 @@ namespace catalyst { #define GEN_PASS_DEF_STABLEHLOLEGALIZETOSTANDARDPASS #define GEN_PASS_DECL_STABLEHLOLEGALIZETOSTANDARDPASS -#include "mlir-hlo/Passes.h.inc" -#include "mlir-hlo/generated_stablehlo_legalize_to_standard.cpp.inc" +#include "stablehlo/Passes.h.inc" +#include "stablehlo/generated_stablehlo_legalize_to_standard.cpp.inc" } // namespace catalyst From 5126d67f7effbd98c6ab80d3920f03a6d4dc0e9c Mon Sep 17 00:00:00 2001 From: paul0403 Date: Wed, 30 Jul 2025 19:28:40 -0400 Subject: [PATCH 23/63] checkout correct stablehlo commit in CI --- .github/workflows/check-catalyst.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index 9bcecde9c0..e01c7eacf1 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -319,7 +319,7 @@ jobs: uses: actions/checkout@v4 with: repository: openxla/stablehlo - ref: f1f035fea33dcfdd7c471eb7f39174b344003117 + ref: 69d6dae46e1c7de36e6e6973654754f05353cba5 path: mlir/stablehlo - name: Build MLIR Dialects From 4ce1ec130b1021364b610249a3dbd129849407b0 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Wed, 30 Jul 2025 20:19:41 -0400 Subject: [PATCH 24/63] clean py pipeline; TODO-fy cpp pipeline --- frontend/catalyst/pipelines.py | 2 -- mlir/lib/Driver/Pipelines.cpp | 8 +++++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index 6037865e16..9cfaf2a89f 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -227,9 +227,7 @@ def get_hlo_lowering_stage(_options: CompileOptions) -> List[str]: hlo_lowering = [ "canonicalize", "func.func(chlo-legalize-to-stablehlo)", - #"stablehlo-legalize-to-hlo", "func.func(stablehlo-legalize-control-flow)", - #"func.func(stablehlo-legalize-to-linalg)", "func.func(stablehlo-aggressive-simplification)", "stablehlo-legalize-to-linalg{enable-primitive-ops}", "func.func(stablehlo-legalize-to-std)", diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index 9c56c70263..c6dc625f71 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -12,16 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "stablehlo/transforms/Passes.h" +#include + #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" +#include "stablehlo/transforms/Passes.h" -#include "Driver/Pipelines.h" #include "Catalyst/IR/CatalystDialect.h" #include "Catalyst/Transforms/Passes.h" +#include "Driver/Pipelines.h" #include "Gradient/IR/GradientDialect.h" #include "Gradient/Transforms/Passes.h" #include "Mitigation/Transforms/Passes.h" @@ -44,9 +46,9 @@ void createHloLoweringPipeline(OpPassManager &pm) pm.addPass(mlir::createCanonicalizerPass()); pm.addNestedPass(stablehlo::createChloLegalizeToStablehloPass()); - //pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); pm.addNestedPass(catalyst::createStablehloLegalizeControlFlowPass()); //(how do I call this??)pm.addNestedPass(stablehlo::createStablehloLegalizeToLinalgPass()); + //pm.addNestedPass(std::make_unique()); pm.addNestedPass(catalyst::createStablehloLegalizeToStdPass()); pm.addNestedPass(catalyst::createStablehloLegalizeSortPass()); pm.addPass(stablehlo::createStablehloConvertToSignlessPass()); From ebd144a2116f0635fc953af4692f6400f00e4772 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 31 Jul 2025 10:17:05 -0400 Subject: [PATCH 25/63] clang format --- mlir/include/stablehlo/Passes.h | 8 ++++---- .../Catalyst/Transforms/RegisterAllPasses.cpp | 2 +- mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp | 1 - mlir/lib/Driver/Pipelines.cpp | 5 +++-- .../stablehlo_legalize_control_flow.cpp | 16 +++++++++------- mlir/lib/stablehlo/stablehlo_legalize_sort.cpp | 3 +-- mlir/lib/stablehlo/stablehlo_legalize_to_std.cpp | 2 +- 7 files changed, 19 insertions(+), 18 deletions(-) diff --git a/mlir/include/stablehlo/Passes.h b/mlir/include/stablehlo/Passes.h index 260c9f9ba1..d06976f017 100644 --- a/mlir/include/stablehlo/Passes.h +++ b/mlir/include/stablehlo/Passes.h @@ -19,7 +19,7 @@ #include "mlir/Pass/Pass.h" namespace catalyst { - std::unique_ptr createStablehloLegalizeSortPass(); - std::unique_ptr createStablehloLegalizeToStdPass(); - std::unique_ptr createStablehloLegalizeControlFlowPass(); -} +std::unique_ptr createStablehloLegalizeSortPass(); +std::unique_ptr createStablehloLegalizeToStdPass(); +std::unique_ptr createStablehloLegalizeControlFlowPass(); +} // namespace catalyst diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index 85bed5358e..03eb06e731 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -21,8 +21,8 @@ #include "Mitigation/Transforms/Passes.h" #include "QEC/Transforms/Passes.h" #include "Quantum/Transforms/Passes.h" -#include "stablehlo/Passes.h" #include "Test/Transforms/Passes.h" +#include "stablehlo/Passes.h" void catalyst::registerAllCatalystPasses() { diff --git a/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp b/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp index eeb462b399..56312a719b 100644 --- a/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp +++ b/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp @@ -25,7 +25,6 @@ #include "stablehlo/dialect/StablehloOps.h" - using namespace mlir; namespace catalyst { diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index c6dc625f71..849e55c696 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -47,8 +47,9 @@ void createHloLoweringPipeline(OpPassManager &pm) pm.addNestedPass(stablehlo::createChloLegalizeToStablehloPass()); pm.addNestedPass(catalyst::createStablehloLegalizeControlFlowPass()); - //(how do I call this??)pm.addNestedPass(stablehlo::createStablehloLegalizeToLinalgPass()); - //pm.addNestedPass(std::make_unique()); + //(how do I call + //this??)pm.addNestedPass(stablehlo::createStablehloLegalizeToLinalgPass()); + // pm.addNestedPass(std::make_unique()); pm.addNestedPass(catalyst::createStablehloLegalizeToStdPass()); pm.addNestedPass(catalyst::createStablehloLegalizeSortPass()); pm.addPass(stablehlo::createStablehloConvertToSignlessPass()); diff --git a/mlir/lib/stablehlo/stablehlo_legalize_control_flow.cpp b/mlir/lib/stablehlo/stablehlo_legalize_control_flow.cpp index fd28262c46..0538e9765c 100644 --- a/mlir/lib/stablehlo/stablehlo_legalize_control_flow.cpp +++ b/mlir/lib/stablehlo/stablehlo_legalize_control_flow.cpp @@ -41,7 +41,6 @@ limitations under the License. #include #include -#include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project @@ -57,6 +56,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/transforms/Passes.h" +#include "llvm/Support/Casting.h" #include "stablehlo/Passes.h" @@ -189,7 +189,8 @@ struct WhileOpPattern : public OpConversionPattern { // needs to be extracted and used with an scf.condition. rewriter.inlineRegionBefore(op.getCond(), newWhileOp.getBefore(), newWhileOp.getBefore().end()); - auto conditionReturn = cast(newWhileOp.getBefore().front().getTerminator()); + auto conditionReturn = + cast(newWhileOp.getBefore().front().getTerminator()); rewriter.setInsertionPointToEnd(&newWhileOp.getBefore().front()); Value i1 = extractTensorValue(rewriter, conditionReturn->getOperand(0)); rewriter.replaceOpWithNewOp(conditionReturn, i1, @@ -243,16 +244,16 @@ struct CaseOpPattern : public OpConversionPattern { auto scfIf = outerBuilder.create( loc, op.getResultTypes(), extractTensorValue(outerBuilder, - outerBuilder.create(loc, idxValue, currentIdxVal, - ComparisonDirection::EQ)), + outerBuilder.create( + loc, idxValue, currentIdxVal, ComparisonDirection::EQ)), /*withElseRegion=*/true); inlineStablehloRegionIntoSCFRegion(outerBuilder, op.getBranches()[currentIdx], - scfIf.getThenRegion()); + scfIf.getThenRegion()); int nextIdx = currentIdx + 1; // Don't recurse for the final default block. if (currentIdx == static_cast(finalIdx)) { inlineStablehloRegionIntoSCFRegion(outerBuilder, op.getBranches()[nextIdx], - scfIf.getElseRegion()); + scfIf.getElseRegion()); } else { PatternRewriter::InsertionGuard guard(outerBuilder); @@ -285,7 +286,8 @@ struct CaseOpPattern : public OpConversionPattern { }; struct StablehloLegalizeControlFlowPass - : public catalyst::impl::StablehloLegalizeControlFlowPassBase { + : public catalyst::impl::StablehloLegalizeControlFlowPassBase< + StablehloLegalizeControlFlowPass> { // Perform the lowering to MLIR control flow. void runOnOperation() override { diff --git a/mlir/lib/stablehlo/stablehlo_legalize_sort.cpp b/mlir/lib/stablehlo/stablehlo_legalize_sort.cpp index 4418c9d3cf..dc2b296ab6 100644 --- a/mlir/lib/stablehlo/stablehlo_legalize_sort.cpp +++ b/mlir/lib/stablehlo/stablehlo_legalize_sort.cpp @@ -61,9 +61,9 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/STLExtras.h" #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/transforms/Passes.h" +#include "llvm/ADT/STLExtras.h" #include "stablehlo/Passes.h" @@ -73,7 +73,6 @@ using namespace catalyst; namespace catalyst { - #define GEN_PASS_DEF_STABLEHLOLEGALIZESORTPASS #define GEN_PASS_DECL_STABLEHLOLEGALIZESORTPASS #include "stablehlo/Passes.h.inc" diff --git a/mlir/lib/stablehlo/stablehlo_legalize_to_std.cpp b/mlir/lib/stablehlo/stablehlo_legalize_to_std.cpp index 9994fd7aa8..bfab0f4673 100644 --- a/mlir/lib/stablehlo/stablehlo_legalize_to_std.cpp +++ b/mlir/lib/stablehlo/stablehlo_legalize_to_std.cpp @@ -42,7 +42,7 @@ limitations under the License. #include #include -//#include "mhlo/transforms/rewriters.h" // (??) +// #include "mhlo/transforms/rewriters.h" // (??) #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" From e900dedd70cd45a25e6955ea2a8068ca715bb04e Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 31 Jul 2025 10:19:56 -0400 Subject: [PATCH 26/63] restore dialects unit tests --- mlir/CMakeLists.txt | 88 ++++++++++++++++++++++----------------------- mlir/Makefile | 2 +- 2 files changed, 45 insertions(+), 45 deletions(-) diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index 69dc46763f..6b38bfd502 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -169,50 +169,50 @@ add_subdirectory(cmake/modules) # Handle unittests when building out-of-tree against an installed version of # LLVM/MLIR (not a build tree). Adapted from `llvm/flang/CMakeLists.txt`. -# set(CATALYST_GTEST_AVAILABLE 0) -# if (TARGET llvm_gtest) -# # Installed gtest, via LLVM_INSTALL_GTEST. Preferred. -# message(STATUS "LLVM GTest found, enabling unittests") -# set(CATALYST_GTEST_AVAILABLE 1) -# else() -# find_package(Threads REQUIRED) -# set(LLVM_THIRD_PARTY_DIR llvm-project/third-party) -# set(UNITTEST_DIR ${LLVM_THIRD_PARTY_DIR}/unittest) -# if (NOT EXISTS ${UNITTEST_DIR}/googletest/include/gtest/gtest.h) -# set(UNITTEST_DIR ${CMAKE_CURRENT_SOURCE_DIR}/llvm/third-party/unittest) -# endif() -# if (EXISTS ${UNITTEST_DIR}/googletest/include/gtest/gtest.h) -# add_llvm_library(llvm_gtest -# ${UNITTEST_DIR}/googletest/src/gtest-all.cc -# ${UNITTEST_DIR}/googlemock/src/gmock-all.cc -# LINK_COMPONENTS Support # llvm::raw_ostream -# BUILDTREE_ONLY -# ) -# target_include_directories(llvm_gtest SYSTEM -# PUBLIC -# "${UNITTEST_DIR}/googletest/include" -# "${UNITTEST_DIR}/googlemock/include" -# PRIVATE -# "${UNITTEST_DIR}/googletest" -# "${UNITTEST_DIR}/googlemock" -# ) -# target_link_libraries(llvm_gtest PUBLIC Threads::Threads) -# add_llvm_library(llvm_gtest_main -# ${UNITTEST_DIR}/UnitTestMain/TestMain.cpp -# LINK_LIBS llvm_gtest -# LINK_COMPONENTS Support # llvm::cl -# BUILDTREE_ONLY -# ) -# set(CATALYST_GTEST_AVAILABLE 1) -# else() -# message(WARNING "Skipping unittests since LLVM install does not include \ -# gtest headers and libraries") -# set(CATALYST_GTEST_AVAILABLE 0) -# endif() -# endif() -# if (CATALYST_GTEST_AVAILABLE) -# add_subdirectory(unittests) -# endif() +set(CATALYST_GTEST_AVAILABLE 0) +if (TARGET llvm_gtest) + # Installed gtest, via LLVM_INSTALL_GTEST. Preferred. + message(STATUS "LLVM GTest found, enabling unittests") + set(CATALYST_GTEST_AVAILABLE 1) +else() + find_package(Threads REQUIRED) + set(LLVM_THIRD_PARTY_DIR llvm-project/third-party) + set(UNITTEST_DIR ${LLVM_THIRD_PARTY_DIR}/unittest) + if (NOT EXISTS ${UNITTEST_DIR}/googletest/include/gtest/gtest.h) + set(UNITTEST_DIR ${CMAKE_CURRENT_SOURCE_DIR}/llvm/third-party/unittest) + endif() + if (EXISTS ${UNITTEST_DIR}/googletest/include/gtest/gtest.h) + add_llvm_library(llvm_gtest + ${UNITTEST_DIR}/googletest/src/gtest-all.cc + ${UNITTEST_DIR}/googlemock/src/gmock-all.cc + LINK_COMPONENTS Support # llvm::raw_ostream + BUILDTREE_ONLY + ) + target_include_directories(llvm_gtest SYSTEM + PUBLIC + "${UNITTEST_DIR}/googletest/include" + "${UNITTEST_DIR}/googlemock/include" + PRIVATE + "${UNITTEST_DIR}/googletest" + "${UNITTEST_DIR}/googlemock" + ) + target_link_libraries(llvm_gtest PUBLIC Threads::Threads) + add_llvm_library(llvm_gtest_main + ${UNITTEST_DIR}/UnitTestMain/TestMain.cpp + LINK_LIBS llvm_gtest + LINK_COMPONENTS Support # llvm::cl + BUILDTREE_ONLY + ) + set(CATALYST_GTEST_AVAILABLE 1) + else() + message(WARNING "Skipping unittests since LLVM install does not include \ + gtest headers and libraries") + set(CATALYST_GTEST_AVAILABLE 0) + endif() +endif() +if (CATALYST_GTEST_AVAILABLE) + add_subdirectory(unittests) +endif() ###################### # End of CIRCT code # diff --git a/mlir/Makefile b/mlir/Makefile index cce8c3203d..09eb2322d2 100644 --- a/mlir/Makefile +++ b/mlir/Makefile @@ -161,7 +161,7 @@ dialects: -DLLVM_ENABLE_ZSTD=$(ENABLE_ZSTD) \ -DCATALYST_ENABLE_WARNINGS=$(STRICT_WARNINGS) - cmake --build $(DIALECTS_BUILD_DIR) --target check-dialects quantum-lsp-server catalyst-cli #check-unit-tests + cmake --build $(DIALECTS_BUILD_DIR) --target check-dialects quantum-lsp-server catalyst-cli check-unit-tests .PHONY: test test: From 4268124c2aad5dc7a4d9d1008f5e5abe3bb8de32 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 31 Jul 2025 11:59:12 -0400 Subject: [PATCH 27/63] burn away unnecessary cmake targets --- mlir/lib/Catalyst/Transforms/CMakeLists.txt | 4 ---- mlir/lib/Driver/CMakeLists.txt | 7 +------ mlir/lib/stablehlo/CMakeLists.txt | 4 ---- mlir/tools/quantum-opt/CMakeLists.txt | 3 --- 4 files changed, 1 insertion(+), 17 deletions(-) diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index fa34e7e665..437a00b70f 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -34,14 +34,10 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) set(LIBS ${dialect_libs} ${conversion_libs} - StablehloPasses - StablehloOps ) set(DEPENDS MLIRCatalystPassIncGen - StablehloBaseIncGen - StablehloOpsIncGen ) add_mlir_library(${LIBRARY_NAME} STATIC ${SRC} LINK_LIBS PRIVATE ${LIBS} DEPENDS ${DEPENDS}) target_compile_features(${LIBRARY_NAME} PUBLIC cxx_std_20) diff --git a/mlir/lib/Driver/CMakeLists.txt b/mlir/lib/Driver/CMakeLists.txt index fa448ed77c..806b585b2b 100644 --- a/mlir/lib/Driver/CMakeLists.txt +++ b/mlir/lib/Driver/CMakeLists.txt @@ -40,10 +40,9 @@ set(LIBS mitigation-transforms MLIRIon ion-transforms - StablehloRegister - StablehloOptimizationPasses MLIRCatalystTest ${ENZYME_LIB} + StablehloRegister StablehloCAPI ) @@ -52,10 +51,6 @@ add_mlir_library(CatalystCompilerDriver CatalystLLVMTarget.cpp Pipelines.cpp - DEPENDS - StablehloBaseIncGen - OptimizationPassesIncGen - LINK_LIBS PRIVATE ${LIBS} ) diff --git a/mlir/lib/stablehlo/CMakeLists.txt b/mlir/lib/stablehlo/CMakeLists.txt index 0a256cfa74..8018f749ed 100644 --- a/mlir/lib/stablehlo/CMakeLists.txt +++ b/mlir/lib/stablehlo/CMakeLists.txt @@ -11,16 +11,12 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) set(LIBS ${dialect_libs} ${conversion_libs} - StablehloPasses - StablehloOps ) set(DEPENDS MLIRCatalystPassIncGen STABLEHLOCatalystPassIncGen MLIRStablehloLegalizeToStandardIncGen - StablehloBaseIncGen - StablehloOpsIncGen ) add_mlir_library(${LIBRARY_NAME} STATIC ${SRC} LINK_LIBS PRIVATE ${LIBS} DEPENDS ${DEPENDS}) diff --git a/mlir/tools/quantum-opt/CMakeLists.txt b/mlir/tools/quantum-opt/CMakeLists.txt index 00b7dc7516..2242bd61fa 100644 --- a/mlir/tools/quantum-opt/CMakeLists.txt +++ b/mlir/tools/quantum-opt/CMakeLists.txt @@ -22,9 +22,6 @@ set(LIBS MLIRIon ion-transforms StablehloRegister - StablehloPasses - StablehloOptimizationPasses - StablehloOps StablehloCAPI MLIRCatalystTest MLIRCatalystUtils From 58d1db4a2d1c03a9826ddb22530ed1db0e1805a8 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 31 Jul 2025 12:12:10 -0400 Subject: [PATCH 28/63] finalize py pipeline, everything passes now --- frontend/catalyst/pipelines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index 9cfaf2a89f..d7572d562d 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -229,7 +229,7 @@ def get_hlo_lowering_stage(_options: CompileOptions) -> List[str]: "func.func(chlo-legalize-to-stablehlo)", "func.func(stablehlo-legalize-control-flow)", "func.func(stablehlo-aggressive-simplification)", - "stablehlo-legalize-to-linalg{enable-primitive-ops}", + "stablehlo-legalize-to-linalg", "func.func(stablehlo-legalize-to-std)", "func.func(stablehlo-legalize-sort)", "stablehlo-convert-to-signless", From 3f317e8246ba7391423454b7ebad79c733522606 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 31 Jul 2025 13:01:06 -0400 Subject: [PATCH 29/63] cpp pipeline --- mlir/lib/Driver/Pipelines.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index 849e55c696..e5062457c7 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -19,7 +19,9 @@ #include "mlir/InitAllPasses.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" +#include "stablehlo/conversions/linalg/transforms/Passes.h" #include "stablehlo/transforms/Passes.h" +#include "stablehlo/transforms/optimization/Passes.h" #include "Catalyst/IR/CatalystDialect.h" #include "Catalyst/Transforms/Passes.h" @@ -47,20 +49,20 @@ void createHloLoweringPipeline(OpPassManager &pm) pm.addNestedPass(stablehlo::createChloLegalizeToStablehloPass()); pm.addNestedPass(catalyst::createStablehloLegalizeControlFlowPass()); - //(how do I call - //this??)pm.addNestedPass(stablehlo::createStablehloLegalizeToLinalgPass()); - // pm.addNestedPass(std::make_unique()); + stablehlo::StablehloAggressiveSimplificationPassOptions ASoptions; + pm.addNestedPass( + stablehlo::createStablehloAggressiveSimplificationPass(ASoptions)); + pm.addNestedPass(stablehlo::createStablehloLegalizeToLinalgPass()); pm.addNestedPass(catalyst::createStablehloLegalizeToStdPass()); pm.addNestedPass(catalyst::createStablehloLegalizeSortPass()); pm.addPass(stablehlo::createStablehloConvertToSignlessPass()); - pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(catalyst::createScatterLoweringPass()); pm.addPass(catalyst::createHloCustomCallLoweringPass()); pm.addPass(mlir::createCSEPass()); - mlir::LinalgDetensorizePassOptions options; - options.aggressiveMode = true; - pm.addNestedPass(mlir::createLinalgDetensorizePass(options)); + mlir::LinalgDetensorizePassOptions LDoptions; + LDoptions.aggressiveMode = true; + pm.addNestedPass(mlir::createLinalgDetensorizePass(LDoptions)); pm.addPass(catalyst::createDetensorizeSCFPass()); pm.addPass(mlir::createCanonicalizerPass()); } From 92fcd5f98e7b5f827fd2afff9712f01e585deaec Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 31 Jul 2025 13:12:56 -0400 Subject: [PATCH 30/63] set CI stablehlo version constant --- .dep-versions | 2 +- .github/workflows/constants.yaml | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.dep-versions b/.dep-versions index eca868887c..8e2b2b90f3 100644 --- a/.dep-versions +++ b/.dep-versions @@ -2,7 +2,7 @@ # To update JAX version alongside compatible dependency tags, run the following script: # python3 .github/workflows/set_dep_versions.py {JAX_version} jax=0.6.2 -mhlo=1dd2e71331014ae0373f6bf900ce6be393357190 +stablehlo=69d6dae46e1c7de36e6e6973654754f05353cba5 llvm=f8cb7987c64dcffb72414a40560055cb717dbf74 enzyme=v0.0.186 diff --git a/.github/workflows/constants.yaml b/.github/workflows/constants.yaml index 7e6510517f..52a2ad7dcc 100644 --- a/.github/workflows/constants.yaml +++ b/.github/workflows/constants.yaml @@ -19,9 +19,9 @@ on: llvm_version: description: "LLVM version" value: ${{ jobs.set-constants.outputs.llvm_version }} - mhlo_version: - description: "MHLO version" - value: ${{ jobs.set-constants.outputs.mhlo_version }} + stablehlo_version: + description: "Stablehlo version" + value: ${{ jobs.set-constants.outputs.stablehlo_version }} enzyme_version: description: "Enzyme version" value: ${{ jobs.set-constants.outputs.enzyme_version }} @@ -69,9 +69,9 @@ jobs: id: llvm_version run: echo "llvm_version=$(grep llvm .dep-versions | awk -F '=' '{ print $2 }')" >> $GITHUB_OUTPUT - - name: MHLO version - id: mhlo_version - run: echo "mhlo_version=$(grep mhlo .dep-versions | awk -F '=' '{ print $2 }')" >> $GITHUB_OUTPUT + - name: Stablehlo version + id: stablehlo_version + run: echo "stablehlo_version=$(grep stablehlo .dep-versions | awk -F '=' '{ print $2 }')" >> $GITHUB_OUTPUT - name: Enzyme version id: enzyme_version @@ -113,7 +113,7 @@ jobs: outputs: llvm_version: ${{ steps.llvm_version.outputs.llvm_version }} - mhlo_version: ${{ steps.mhlo_version.outputs.mhlo_version }} + stablehlo_version: ${{ steps.stablehlo_version.outputs.stablehlo_version }} enzyme_version: ${{ steps.enzyme_version.outputs.enzyme_version }} python_versions: ${{ steps.python_versions.outputs.python_versions }} python_test_versions: ${{ steps.python_test_versions.outputs.python_test_versions }} From 9cd6bc86822da2db187ed5c71372a9234ed4deff Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 31 Jul 2025 13:15:43 -0400 Subject: [PATCH 31/63] burn mhlo in check-pl-compat.yaml --- .github/workflows/check-pl-compat.yaml | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/.github/workflows/check-pl-compat.yaml b/.github/workflows/check-pl-compat.yaml index e26fabbb05..a6f9c7f842 100644 --- a/.github/workflows/check-pl-compat.yaml +++ b/.github/workflows/check-pl-compat.yaml @@ -73,19 +73,6 @@ jobs: path: llvm-build key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-ci-build-gcc fail-on-cache-miss: True - - uses: actions/cache/restore@v4 - if: ${{ inputs.catalyst != 'stable' }} - with: - path: mlir/mlir-hlo - key: mhlo-${{ needs.constants.outputs.mhlo_version }}-default-source - enableCrossOsArchive: True - fail-on-cache-miss: True - - uses: actions/cache/restore@v4 - if: ${{ inputs.catalyst != 'stable' }} - with: - path: mhlo-build - key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-ci-build-gcc - fail-on-cache-miss: True - uses: actions/cache/restore@v4 if: ${{ inputs.catalyst != 'stable' }} with: @@ -120,7 +107,6 @@ jobs: ENABLE_LLD=ON \ RT_BUILD_DIR="$(pwd)/runtime-build" \ LLVM_BUILD_DIR="$(pwd)/llvm-build" \ - MHLO_BUILD_DIR="$(pwd)/mhlo-build" \ ENZYME_BUILD_DIR="$(pwd)/enzyme-build" \ DIALECTS_BUILD_DIR="$(pwd)/quantum-build" \ ENABLE_OPENQASM=ON \ From 345989565b58b8e23d5bb49a5503cf6fc325864d Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 31 Jul 2025 13:17:30 -0400 Subject: [PATCH 32/63] track CI stablehlo version variable in main ci --- .github/workflows/check-catalyst.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index e01c7eacf1..fa4311b03b 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -314,12 +314,11 @@ jobs: key: ${{ runner.os }}-ccache-${{ github.run_id }} restore-keys: ${{ runner.os }}-ccache- - # just hard code commit manually, set up this stablehlo dep verisons infra later - name: Clone Stablehlo Submodule uses: actions/checkout@v4 with: repository: openxla/stablehlo - ref: 69d6dae46e1c7de36e6e6973654754f05353cba5 + ref: ${{ needs.constants.outputs.stablehlo_version }} path: mlir/stablehlo - name: Build MLIR Dialects From 937433445fc14d612a6d62a5252e816a2a92a2c2 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 31 Jul 2025 13:23:27 -0400 Subject: [PATCH 33/63] delete mhlo patches --- mlir/patches/mhlo-remove-shardy.patch | 132 -------------------------- mlir/patches/mhlo-rename-sort.patch | 15 --- 2 files changed, 147 deletions(-) delete mode 100644 mlir/patches/mhlo-remove-shardy.patch delete mode 100644 mlir/patches/mhlo-rename-sort.patch diff --git a/mlir/patches/mhlo-remove-shardy.patch b/mlir/patches/mhlo-remove-shardy.patch deleted file mode 100644 index 32ce71061f..0000000000 --- a/mlir/patches/mhlo-remove-shardy.patch +++ /dev/null @@ -1,132 +0,0 @@ -From 70172e8399383d6c1964d73a2d20cba3c55a3279 Mon Sep 17 00:00:00 2001 -From: paul0403 -Date: Thu, 29 May 2025 10:06:35 -0400 -Subject: [PATCH] remove shardy dependency - ---- - bindings/c/CMakeLists.txt | 1 - - stablehlo_ext/CMakeLists.txt | 1 + - stablehlo_ext/analysis/CMakeLists.txt | 3 ++- - stablehlo_ext/transforms/CMakeLists.txt | 7 ++++++- - stablehlo_ext/transforms/stablehlo_refine_shapes.cpp | 3 --- - tests/lit.cfg.py | 1 + - tools/mlir-hlo-opt/mlir-hlo-opt.cc | 2 -- - 7 files changed, 10 insertions(+), 8 deletions(-) - -diff --git a/bindings/c/CMakeLists.txt b/bindings/c/CMakeLists.txt -index fd2a5c2c..53d916d5 100644 ---- a/bindings/c/CMakeLists.txt -+++ b/bindings/c/CMakeLists.txt -@@ -10,7 +10,6 @@ add_mlir_public_c_api_library(MLIRHLOCAPIDialects - MhloPasses - MhloToArithmeticConversion - MhloToMemrefConversion -- MhloToStandard - MhloToLinalg - MhloToStablehlo - StablehloToMhlo -diff --git a/stablehlo_ext/CMakeLists.txt b/stablehlo_ext/CMakeLists.txt -index 3e55a89d..e8d318f1 100644 ---- a/stablehlo_ext/CMakeLists.txt -+++ b/stablehlo_ext/CMakeLists.txt -@@ -12,5 +12,6 @@ - # See the License for the specific language governing permissions and - # limitations under the License. - -+add_subdirectory(analysis) - add_subdirectory(IR) - add_subdirectory(transforms) -diff --git a/stablehlo_ext/analysis/CMakeLists.txt b/stablehlo_ext/analysis/CMakeLists.txt -index 726d340d..0c0259b8 100644 ---- a/stablehlo_ext/analysis/CMakeLists.txt -+++ b/stablehlo_ext/analysis/CMakeLists.txt -@@ -1,5 +1,6 @@ - add_mlir_library(MhloAnalysis -- shape_component_analysis.cc -+ shape_component_analysis.cpp -+ PARTIAL_SOURCES_INTENDED - - DEPENDS - mlir-headers -diff --git a/stablehlo_ext/transforms/CMakeLists.txt b/stablehlo_ext/transforms/CMakeLists.txt -index ee58f490..2d7cc22c 100644 ---- a/stablehlo_ext/transforms/CMakeLists.txt -+++ b/stablehlo_ext/transforms/CMakeLists.txt -@@ -20,9 +20,14 @@ add_mlir_dialect_library(StablehloExtensionPasses - PARTIAL_SOURCES_INTENDED - chlo_recompose_ops.cpp - chlo_preserve_high_level_ops.cpp -+ sink_constants_to_control_flow.cpp -+ stablehlo_add_quant_dequant_conv.cpp - stablehlo_canonicalize_dynamism.cpp -+ stablehlo_canonicalize_from_hlo_import.cpp -+ stablehlo_legalize_quant_composite.cpp -+ stablehlo_prepare_for_hlo_export.cpp - stablehlo_refine_shapes.cpp -- sdy_refine_shapes.cpp -+ symbolic_shape_optimization.cpp - - DEPENDS - StablehloExtensionPassesIncGen -diff --git a/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp b/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp -index cabd6a9f..2e64b4ed 100644 ---- a/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp -+++ b/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp -@@ -34,7 +34,6 @@ limitations under the License. - #include "stablehlo_ext/IR/base.h" - #include "stablehlo_ext/IR/stablehlo_ops.h" - #include "stablehlo_ext/transforms/passes.h" // NOLINT: Used in passes.h.inc --#include "stablehlo_ext/transforms/sdy_refine_shapes.h" - - namespace mlir { - namespace stablehlo_ext { -@@ -154,7 +153,6 @@ struct StablehloRefineShapesPass - patterns->add(context); - patterns->add(context); - patterns->add(context); -- populateSdyShapeRefinementPatterns(context, patterns); - }; - - if (failed(stablehlo::refineEntryFunction(*context, func, -@@ -172,7 +170,6 @@ void populateStablehloExtRefineShapesPatterns(RewritePatternSet *patterns, - patterns->add(context); - patterns->add(context); - patterns->add(context); -- populateSdyShapeRefinementPatterns(context, patterns); - } - - } // namespace stablehlo_ext -diff --git a/tests/lit.cfg.py b/tests/lit.cfg.py -index ab20fbb5..6c61aec5 100644 ---- a/tests/lit.cfg.py -+++ b/tests/lit.cfg.py -@@ -32,6 +32,7 @@ config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) - - # suffixes: A list of file extensions to treat as test files. - config.suffixes = ['.mlir'] -+config.excludes = ['sdy_refine_shapes.mlir'] - - # test_source_root: The root path where tests are located. - config.test_source_root = os.path.dirname(__file__) -diff --git a/tools/mlir-hlo-opt/mlir-hlo-opt.cc b/tools/mlir-hlo-opt/mlir-hlo-opt.cc -index f018cbdc..b4474850 100644 ---- a/tools/mlir-hlo-opt/mlir-hlo-opt.cc -+++ b/tools/mlir-hlo-opt/mlir-hlo-opt.cc -@@ -20,7 +20,6 @@ limitations under the License. - #include "mlir/InitAllExtensions.h" - #include "mlir/InitAllPasses.h" - #include "mlir/Tools/mlir-opt/MlirOptMain.h" --#include "shardy/dialect/sdy/ir/dialect.h" - #include "stablehlo/dialect/Register.h" - #include "stablehlo_ext/transforms/passes.h" - #include "transforms/gpu_passes.h" -@@ -41,6 +40,5 @@ int main(int argc, char** argv) { - registerAllExtensions(registry); - mhlo::registerAllMhloDialects(registry); - stablehlo::registerAllDialects(registry); -- registry.insert(); - return failed(MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry)); - } --- -2.34.1 - diff --git a/mlir/patches/mhlo-rename-sort.patch b/mlir/patches/mhlo-rename-sort.patch deleted file mode 100644 index c356cc35e3..0000000000 --- a/mlir/patches/mhlo-rename-sort.patch +++ /dev/null @@ -1,15 +0,0 @@ -diff --git a/utils/cycle_detector.cc b/utils/cycle_detector.cc -index e3901ae88..890f39654 100644 ---- a/utils/cycle_detector.cc -+++ b/utils/cycle_detector.cc -@@ -199,8 +199,8 @@ static void backwardDfs(GraphCycles::Rep* r, int32_t n, int32_t lowerBound) { - // Recomputes rank assignments to make them compatible with the edges (producer - // has smaller rank than its consumer) - static void reorder(GraphCycles::Rep* r) { -- sort(r->nodes, &r->deltab); -- sort(r->nodes, &r->deltaf); -+ mlir::sort(r->nodes, &r->deltab); -+ mlir::sort(r->nodes, &r->deltaf); - - // Adds contents of delta lists to list (backwards deltas first). - r->list.clear(); From 8d3e68a50e61e29e277c8ba0da8ed7a328f49e56 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 31 Jul 2025 13:24:00 -0400 Subject: [PATCH 34/63] remove mhlo from linux arm wheel --- .../workflows/build-wheel-linux-arm64.yaml | 87 ++----------------- 1 file changed, 8 insertions(+), 79 deletions(-) diff --git a/.github/workflows/build-wheel-linux-arm64.yaml b/.github/workflows/build-wheel-linux-arm64.yaml index 4b83d7ae38..7e1e84b797 100644 --- a/.github/workflows/build-wheel-linux-arm64.yaml +++ b/.github/workflows/build-wheel-linux-arm64.yaml @@ -84,14 +84,6 @@ jobs: key: llvm-${{ needs.constants.outputs.llvm_version }}-container-source enableCrossOsArchive: True - - name: Cache MHLO Source - id: cache-mhlo-source - uses: actions/cache@v4 - with: - path: ${{ github.workspace }}/mlir/mlir-hlo - key: mhlo-${{ needs.constants.outputs.mhlo_version }}-container-source - enableCrossOsArchive: True - - name: Cache Enzyme Source id: cache-enzyme-source uses: actions/cache@v4 @@ -109,26 +101,11 @@ jobs: path: ${{ github.workspace }}/mlir/llvm-project - name: Patch LLVM Source - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' + if: steps.cache-llvm-source.outputs.cache-hit != 'true' run: | cd $GITHUB_WORKSPACE/mlir/llvm-project git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch - - name: Clone MHLO Submodule - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' - uses: actions/checkout@v4 - with: - repository: tensorflow/mlir-hlo - ref: ${{ needs.constants.outputs.mhlo_version }} - path: ${{ github.workspace }}/mlir/mlir-hlo - - - name: Patch MHLO Source - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' - run: | - cd $GITHUB_WORKSPACE/mlir/mlir-hlo - git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch - git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-rename-sort.patch - - name: Clone Enzyme Submodule if: steps.cache-enzyme-source.outputs.cache-hit != 'true' uses: actions/checkout@v4 @@ -151,14 +128,6 @@ jobs: path: ${{ github.workspace }}/llvm-build key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{ matrix.python_version }}-wheel-build - - name: Check MHLO Build Cache - id: cache-mhlo-build - uses: actions/cache/restore@v4 - with: - path: ${{ github.workspace }}/mhlo-build - key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build - lookup-only: True - - name: Check Enzyme Build Cache id: cache-enzyme-build uses: actions/cache/restore@v4 @@ -170,7 +139,6 @@ jobs: - name: Install dependencies if: | steps.cache-llvm-build.outputs.cache-hit != 'true' || - steps.cache-mhlo-build.outputs.cache-hit != 'true' || steps.cache-enzyme-build.outputs.cache-hit != 'true' run: | cat /etc/dnf.conf | sed "s/\[main\]/\[main\]\ntimeout=5/g" > /etc/dnf.conf @@ -207,32 +175,6 @@ jobs: path: ${{ github.workspace }}/llvm-build key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{matrix.python_version}}-wheel-build - - name: Build MHLO Dialect - if: steps.cache-mhlo-build.outputs.cache-hit != 'true' - # building with LLD is a strong requirement for mhlo - run: | - export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH - - cmake -S mlir/mlir-hlo -B $GITHUB_WORKSPACE/mhlo-build -G Ninja \ - -DCMAKE_BUILD_TYPE=Release \ - -DLLVM_ENABLE_ASSERTIONS=ON \ - -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ - -DPython3_EXECUTABLE=$(which python${{ matrix.python_version }}) \ - -DLLVM_ENABLE_ZLIB=FORCE_ON \ - -DLLVM_ENABLE_ZSTD=OFF \ - -DCMAKE_CXX_VISIBILITY_PRESET=default \ - -DLLVM_ENABLE_LLD=ON - - LIT_FILTER_OUT="chlo_legalize_to_mhlo" cmake --build $GITHUB_WORKSPACE/mhlo-build --target check-mlir-hlo - - - name: Save MHLO Build - id: save-mhlo-build - if: steps.cache-mhlo-build.outputs.cache-hit != 'true' - uses: actions/cache/save@v4 - with: - path: ${{ github.workspace }}/mhlo-build - key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build - - name: Build Enzyme if: steps.cache-enzyme-build.outputs.cache-hit != 'true' run: | @@ -305,23 +247,6 @@ jobs: key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-3.11-wheel-build fail-on-cache-miss: True - - name: Get Cached MHLO Source - id: cache-mhlo-source - uses: actions/cache/restore@v4 - with: - path: ${{ github.workspace }}/mlir/mlir-hlo - key: mhlo-${{ needs.constants.outputs.mhlo_version }}-container-source - enableCrossOsArchive: True - fail-on-cache-miss: True - - - name: Get Cached MHLO Build - id: cache-mhlo-build - uses: actions/cache/restore@v4 - with: - path: ${{ github.workspace }}/mhlo-build - key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build - fail-on-cache-miss: True - - name: Get Cached Enzyme Source id: cache-enzyme-source uses: actions/cache/restore@v4 @@ -360,6 +285,13 @@ jobs: PYTHON=$(which python${{ matrix.python_version }}) \ make oqc + - name: Clone Stablehlo Submodule + uses: actions/checkout@v4 + with: + repository: openxla/stablehlo + ref: ${{ needs.constants.outputs.stablehlo_version }} + path: mlir/stablehlo + # Build Quantum and Gradient Dialects - name: Build MLIR Dialects run: | @@ -372,8 +304,6 @@ jobs: -DPython3_EXECUTABLE=$(which python${{ matrix.python_version }}) \ -DPython3_NumPy_INCLUDE_DIRS=$(python${{ matrix.python_version }} -c "import numpy as np; print(np.get_include())") \ -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ - -DMHLO_DIR="$GITHUB_WORKSPACE/mhlo-build/lib/cmake/mlir-hlo" \ - -DMHLO_BINARY_DIR="$GITHUB_WORKSPACE/mhlo-build/bin" \ -DEnzyme_DIR="$GITHUB_WORKSPACE/enzyme-build" \ -DENZYME_SRC_DIR="$GITHUB_WORKSPACE/mlir/Enzyme" \ -DLLVM_ENABLE_ZLIB=FORCE_ON \ @@ -396,7 +326,6 @@ jobs: run: | PYTHON=python${{ matrix.python_version }} \ LLVM_BUILD_DIR="$GITHUB_WORKSPACE/llvm-build" \ - MHLO_BUILD_DIR="$GITHUB_WORKSPACE/mhlo-build" \ DIALECTS_BUILD_DIR="$GITHUB_WORKSPACE/quantum-build" \ RT_BUILD_DIR="$GITHUB_WORKSPACE/runtime-build" \ OQC_BUILD_DIR="$GITHUB_WORKSPACE/oqc-build" \ From 03f59f017b96d4ecffc7fb829b5666b4ddfbba1f Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 31 Jul 2025 13:27:41 -0400 Subject: [PATCH 35/63] minux x86 wheel --- .../workflows/build-wheel-linux-x86_64.yaml | 88 ++----------------- 1 file changed, 8 insertions(+), 80 deletions(-) diff --git a/.github/workflows/build-wheel-linux-x86_64.yaml b/.github/workflows/build-wheel-linux-x86_64.yaml index abd4f069ab..957bca3d79 100644 --- a/.github/workflows/build-wheel-linux-x86_64.yaml +++ b/.github/workflows/build-wheel-linux-x86_64.yaml @@ -103,14 +103,6 @@ jobs: key: llvm-${{ needs.constants.outputs.llvm_version }}-container-source enableCrossOsArchive: True - - name: Cache MHLO Source - id: cache-mhlo-source - uses: actions/cache@v4 - with: - path: ${{ github.workspace }}/mlir/mlir-hlo - key: mhlo-${{ needs.constants.outputs.mhlo_version }}-container-source - enableCrossOsArchive: True - - name: Cache Enzyme Source id: cache-enzyme-source uses: actions/cache@v4 @@ -128,26 +120,11 @@ jobs: path: ${{ github.workspace }}/mlir/llvm-project - name: Patch LLVM Source - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' + if: steps.cache-llvm-source.outputs.cache-hit != 'true' run: | cd $GITHUB_WORKSPACE/mlir/llvm-project git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch - - name: Clone MHLO Submodule - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' - uses: actions/checkout@v4 - with: - repository: tensorflow/mlir-hlo - ref: ${{ needs.constants.outputs.mhlo_version }} - path: ${{ github.workspace }}/mlir/mlir-hlo - - - name: Patch MHLO Source - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' - run: | - cd $GITHUB_WORKSPACE/mlir/mlir-hlo - git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch - git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-rename-sort.patch - - name: Clone Enzyme Submodule if: steps.cache-enzyme-source.outputs.cache-hit != 'true' uses: actions/checkout@v4 @@ -170,14 +147,6 @@ jobs: path: ${{ github.workspace }}/llvm-build key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{ matrix.python_version }}-wheel-build - - name: Check MHLO Build Cache - id: cache-mhlo-build - uses: actions/cache/restore@v4 - with: - path: ${{ github.workspace }}/mhlo-build - key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build - lookup-only: True - - name: Check Enzyme Build Cache id: cache-enzyme-build uses: actions/cache/restore@v4 @@ -189,7 +158,6 @@ jobs: - name: Install dependencies (AlmaLinux) if: | steps.cache-llvm-build.outputs.cache-hit != 'true' || - steps.cache-mhlo-build.outputs.cache-hit != 'true' || steps.cache-enzyme-build.outputs.cache-hit != 'true' run: | # Reduce wait time for repos not responding @@ -205,7 +173,6 @@ jobs: PYTHON_BINS=$(find /opt/_internal/cpython-${{ matrix.python_version }}.*/bin -maxdepth 1 -type d | tr '\n' ':' | sed 's/:$//') echo $PYTHON_BINS >> $GITHUB_PATH - # LLD is required for MHLO builds. # (Don't forget to add the build directory to PATH in subsequent steps, so # other tools can find it, in particular collect2 invoked by gcc.) - name: Build LLVM / MLIR @@ -230,32 +197,6 @@ jobs: path: ${{ github.workspace }}/llvm-build key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{matrix.python_version}}-wheel-build - - name: Build MHLO Dialect - if: steps.cache-mhlo-build.outputs.cache-hit != 'true' - # building with LLD is a strong requirement for mhlo - run: | - export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH - - cmake -S mlir/mlir-hlo -B $GITHUB_WORKSPACE/mhlo-build -G Ninja \ - -DCMAKE_BUILD_TYPE=Release \ - -DLLVM_ENABLE_ASSERTIONS=ON \ - -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ - -DPython3_EXECUTABLE=$(which python${{ matrix.python_version }}) \ - -DLLVM_ENABLE_ZLIB=FORCE_ON \ - -DLLVM_ENABLE_ZSTD=OFF \ - -DCMAKE_CXX_VISIBILITY_PRESET=default \ - -DLLVM_ENABLE_LLD=ON - - LIT_FILTER_OUT="chlo_legalize_to_mhlo" cmake --build $GITHUB_WORKSPACE/mhlo-build --target check-mlir-hlo - - - name: Save MHLO Build - id: save-mhlo-build - if: steps.cache-mhlo-build.outputs.cache-hit != 'true' - uses: actions/cache/save@v4 - with: - path: ${{ github.workspace }}/mhlo-build - key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build - - name: Build Enzyme if: steps.cache-enzyme-build.outputs.cache-hit != 'true' run: | @@ -328,23 +269,6 @@ jobs: key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-3.11-wheel-build fail-on-cache-miss: True - - name: Get Cached MHLO Source - id: cache-mhlo-source - uses: actions/cache/restore@v4 - with: - path: ${{ github.workspace }}/mlir/mlir-hlo - key: mhlo-${{ needs.constants.outputs.mhlo_version }}-container-source - enableCrossOsArchive: True - fail-on-cache-miss: True - - - name: Get Cached MHLO Build - id: cache-mhlo-build - uses: actions/cache/restore@v4 - with: - path: ${{ github.workspace }}/mhlo-build - key: ${{ matrix.container_img }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build - fail-on-cache-miss: True - - name: Get Cached Enzyme Source id: cache-enzyme-source uses: actions/cache/restore@v4 @@ -385,6 +309,13 @@ jobs: PYTHON=$(which python${{ matrix.python_version }}) \ make oqc + - name: Clone Stablehlo Submodule + uses: actions/checkout@v4 + with: + repository: openxla/stablehlo + ref: ${{ needs.constants.outputs.stablehlo_version }} + path: mlir/stablehlo + # Build Quantum and Gradient Dialects - name: Build MLIR Dialects run: | @@ -397,8 +328,6 @@ jobs: -DPython3_EXECUTABLE=$(which python${{ matrix.python_version }}) \ -DPython3_NumPy_INCLUDE_DIRS=$(python${{ matrix.python_version }} -c "import numpy as np; print(np.get_include())") \ -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ - -DMHLO_DIR="$GITHUB_WORKSPACE/mhlo-build/lib/cmake/mlir-hlo" \ - -DMHLO_BINARY_DIR="$GITHUB_WORKSPACE/mhlo-build/bin" \ -DEnzyme_DIR="$GITHUB_WORKSPACE/enzyme-build" \ -DENZYME_SRC_DIR="$GITHUB_WORKSPACE/mlir/Enzyme" \ -DLLVM_ENABLE_ZLIB=FORCE_ON \ @@ -421,7 +350,6 @@ jobs: run: | PYTHON=python${{ matrix.python_version }} \ LLVM_BUILD_DIR="$GITHUB_WORKSPACE/llvm-build" \ - MHLO_BUILD_DIR="$GITHUB_WORKSPACE/mhlo-build" \ DIALECTS_BUILD_DIR="$GITHUB_WORKSPACE/quantum-build" \ RT_BUILD_DIR="$GITHUB_WORKSPACE/runtime-build" \ OQC_BUILD_DIR="$GITHUB_WORKSPACE/oqc-build" \ From 81ff5e3fb758aad37fbc29796f04c17ac4b87992 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 31 Jul 2025 13:30:13 -0400 Subject: [PATCH 36/63] mac arm wheel --- .../workflows/build-wheel-macos-arm64.yaml | 85 ++----------------- 1 file changed, 8 insertions(+), 77 deletions(-) diff --git a/.github/workflows/build-wheel-macos-arm64.yaml b/.github/workflows/build-wheel-macos-arm64.yaml index c1fc0b29a0..6031c5c09f 100644 --- a/.github/workflows/build-wheel-macos-arm64.yaml +++ b/.github/workflows/build-wheel-macos-arm64.yaml @@ -89,14 +89,6 @@ jobs: key: llvm-${{ needs.constants.outputs.llvm_version }}-default-source enableCrossOsArchive: True - - name: Cache MHLO Source - id: cache-mhlo-source - uses: actions/cache@v4 - with: - path: ${{ github.workspace }}/mlir/mlir-hlo - key: mhlo-${{ needs.constants.outputs.mhlo_version }}-default-source - enableCrossOsArchive: True - - name: Cache Enzyme Source id: cache-enzyme-source uses: actions/cache@v4 @@ -114,26 +106,11 @@ jobs: path: ${{ github.workspace }}/mlir/llvm-project - name: Patch LLVM Source - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' + if: steps.cache-llvm-source.outputs.cache-hit != 'true' run: | cd $GITHUB_WORKSPACE/mlir/llvm-project git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch - - name: Clone MHLO Submodule - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' - uses: actions/checkout@v4 - with: - repository: tensorflow/mlir-hlo - ref: ${{ needs.constants.outputs.mhlo_version }} - path: ${{ github.workspace }}/mlir/mlir-hlo - - - name: Patch MHLO Source - if: steps.cache-mhlo-source.outputs.cache-hit != 'true' - run: | - cd $GITHUB_WORKSPACE/mlir/mlir-hlo - git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch - git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-rename-sort.patch - - name: Clone Enzyme Submodule if: steps.cache-enzyme-source.outputs.cache-hit != 'true' uses: actions/checkout@v4 @@ -156,14 +133,6 @@ jobs: path: ${{ github.workspace }}/llvm-build key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{ matrix.python_version }}-wheel-build - - name: Check MHLO Build Cache - id: cache-mhlo-build - uses: actions/cache/restore@v4 - with: - path: ${{ github.workspace }}/mhlo-build - key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build - lookup-only: True - - name: Check Enzyme Build Cache id: cache-enzyme-build uses: actions/cache/restore@v4 @@ -200,31 +169,6 @@ jobs: path: ${{ github.workspace }}/llvm-build key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{ matrix.python_version }}-wheel-build - - name: Build MHLO Dialect - if: steps.cache-mhlo-build.outputs.cache-hit != 'true' - run: | - export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH - - cmake -S mlir/mlir-hlo -B $GITHUB_WORKSPACE/mhlo-build -G Ninja \ - -DCMAKE_BUILD_TYPE=Release \ - -DLLVM_ENABLE_ASSERTIONS=ON \ - -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ - -DPython3_EXECUTABLE=$(which python${{ matrix.python_version }}) \ - -DLLVM_ENABLE_LLD=OFF \ - -DLLVM_ENABLE_ZLIB=FORCE_ON \ - -DLLVM_ENABLE_ZSTD=OFF \ - -DCMAKE_CXX_VISIBILITY_PRESET=default - - cmake --build $GITHUB_WORKSPACE/mhlo-build --target check-mlir-hlo - - - name: Save MHLO Build - id: save-mhlo-build - if: steps.cache-mhlo-build.outputs.cache-hit != 'true' - uses: actions/cache/save@v4 - with: - path: ${{ github.workspace }}/mhlo-build - key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build - - name: Build Enzyme if: steps.cache-enzyme-build.outputs.cache-hit != 'true' run: | @@ -291,23 +235,6 @@ jobs: key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{ needs.constants.outputs.primary_python_version }}-wheel-build fail-on-cache-miss: True - - name: Get Cached MHLO Source - id: cache-mhlo-source - uses: actions/cache/restore@v4 - with: - path: ${{ github.workspace }}/mlir/mlir-hlo - key: mhlo-${{ needs.constants.outputs.mhlo_version }}-default-source - enableCrossOsArchive: True - fail-on-cache-miss: True - - - name: Get Cached MHLO Build - id: cache-mhlo-build - uses: actions/cache/restore@v4 - with: - path: ${{ github.workspace }}/mhlo-build - key: ${{ runner.os }}-${{ runner.arch }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-wheel-build - fail-on-cache-miss: True - - name: Get Cached Enzyme Source id: cache-enzyme-source uses: actions/cache/restore@v4 @@ -360,6 +287,13 @@ jobs: PYTHON=$(which python${{ matrix.python_version }}) \ make oqc + - name: Clone Stablehlo Submodule + uses: actions/checkout@v4 + with: + repository: openxla/stablehlo + ref: ${{ needs.constants.outputs.stablehlo_version }} + path: mlir/stablehlo + # Build Quantum and Gradient Dialects - name: Build MLIR Dialects run: | @@ -370,8 +304,6 @@ jobs: -DPython3_EXECUTABLE=$(which python${{ matrix.python_version }}) \ -DPython3_NumPy_INCLUDE_DIRS=$(python${{ matrix.python_version }} -c "import numpy as np; print(np.get_include())") \ -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ - -DMHLO_DIR="$GITHUB_WORKSPACE/mhlo-build/lib/cmake/mlir-hlo" \ - -DMHLO_BINARY_DIR="$GITHUB_WORKSPACE/mhlo-build/bin" \ -DEnzyme_DIR="$GITHUB_WORKSPACE/enzyme-build" \ -DENZYME_SRC_DIR="$GITHUB_WORKSPACE/mlir/Enzyme" \ -DLLVM_ENABLE_ZLIB=FORCE_ON \ @@ -393,7 +325,6 @@ jobs: run: | PYTHON=python${{ matrix.python_version }} \ LLVM_BUILD_DIR="$GITHUB_WORKSPACE/llvm-build" \ - MHLO_BUILD_DIR="$GITHUB_WORKSPACE/mhlo-build" \ DIALECTS_BUILD_DIR="$GITHUB_WORKSPACE/quantum-build" \ RT_BUILD_DIR="$GITHUB_WORKSPACE/runtime-build" \ OQC_BUILD_DIR="$GITHUB_WORKSPACE/oqc-build" \ From 9760aecf293b04d410ec5e011db65f0ef89775c9 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 31 Jul 2025 13:36:38 -0400 Subject: [PATCH 37/63] clean mhlo from git grep --- .github/workflows/check-jax-release.yaml | 14 +++++--------- .github/workflows/set_dep_versions.py | 2 +- Makefile | 11 ++--------- 3 files changed, 8 insertions(+), 19 deletions(-) diff --git a/.github/workflows/check-jax-release.yaml b/.github/workflows/check-jax-release.yaml index 7e181478c7..c6651d2f23 100644 --- a/.github/workflows/check-jax-release.yaml +++ b/.github/workflows/check-jax-release.yaml @@ -38,7 +38,7 @@ jobs: - name: Re-read versions run: | echo "LLVM_REVISION=$(grep llvm .dep-versions | awk -F '=' '{ print $2 }')" >> $GITHUB_ENV - echo "MHLO_REVISION=$(grep mhlo .dep-versions | awk -F '=' '{ print $2 }')" >> $GITHUB_ENV + echo "STABLEHLO_REVISION=$(grep stablehlo .dep-versions | awk -F '=' '{ print $2 }')" >> $GITHUB_ENV echo "ENZYME_REVISION=$(grep enzyme .dep-versions | awk -F '=' '{ print $2 }')" >> $GITHUB_ENV - name: Clone LLVM repo @@ -48,12 +48,12 @@ jobs: ref: ${{ env.LLVM_REVISION }} path: mlir/llvm-project - - name: Clone MHLO repo + - name: Clone Stablehlo repo uses: actions/checkout@v4 with: - repository: tensorflow/mlir-hlo - ref: ${{ env.MHLO_REVISION }} - path: mlir/mlir-hlo + repository: openxla/stablehlo + ref: ${{ env.STABLEHLO_REVISION }} + path: mlir/stablehlo - name: Clone Enzyme repo uses: actions/checkout@v4 @@ -70,10 +70,6 @@ jobs: run: | make llvm - - name: Build MHLO - run: | - make mhlo - - name: Build Enzyme run: | make enzyme diff --git a/.github/workflows/set_dep_versions.py b/.github/workflows/set_dep_versions.py index ce9a998c18..d2340d4ea4 100644 --- a/.github/workflows/set_dep_versions.py +++ b/.github/workflows/set_dep_versions.py @@ -71,7 +71,7 @@ # Update each version using sed cmds = [ f"sed -i '' 's/^jax=.*/jax={jax_version}/' {dep_versions_path}", - f"sed -i '' 's/^mhlo=.*/mhlo={hlo_commit}/' {dep_versions_path}", + f"sed -i '' 's/^stablehlo=.*/stablehlo={hlo_commit}/' {dep_versions_path}", f"sed -i '' 's/^llvm=.*/llvm={llvm_commit}/' {dep_versions_path}", # Update jaxlib version in __init__.py rf"sed -i '' 's/_jaxlib_version = {quote}\([0-9.]\+\){quote}/_jaxlib_version = {quote}{jax_version}{quote}/g' {catalyst_init_path}", diff --git a/Makefile b/Makefile index 68d81ae519..8103554f3e 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,6 @@ BLACKVERSIONMINOR := $(if $(BLACKVERSIONMINOR),$(BLACKVERSIONMINOR),0) MK_ABSPATH := $(abspath $(lastword $(MAKEFILE_LIST))) MK_DIR := $(dir $(MK_ABSPATH)) LLVM_BUILD_DIR ?= $(MK_DIR)/mlir/llvm-project/build -MHLO_BUILD_DIR ?= $(MK_DIR)/mlir/mlir-hlo/bazel-build DIALECTS_SRC_DIR ?= $(MK_DIR)/mlir DIALECTS_BUILD_DIR ?= $(MK_DIR)/mlir/build RT_BUILD_DIR ?= $(MK_DIR)/runtime/build @@ -119,16 +118,13 @@ frontend: $(PYTHON) -m pip install -e . --extra-index-url https://test.pypi.org/simple $(PIP_VERBOSE_FLAG) rm -r frontend/pennylane_catalyst.egg-info -.PHONY: mlir llvm mhlo enzyme dialects runtime oqc +.PHONY: mlir llvm enzyme dialects runtime oqc mlir: $(MAKE) -C mlir all llvm: $(MAKE) -C mlir llvm -mhlo: - $(MAKE) -C mlir mhlo - enzyme: $(MAKE) -C mlir enzyme @@ -269,7 +265,7 @@ clean: clean-all: clean clean-mlir clean-runtime clean-oqc clean-catalyst: clean clean-dialects clean-runtime clean-oqc -.PHONY: clean-mlir clean-dialects clean-plugin clean-llvm clean-mhlo clean-enzyme +.PHONY: clean-mlir clean-dialects clean-plugin clean-llvm clean-enzyme clean-mlir: $(MAKE) -C mlir clean @@ -285,9 +281,6 @@ clean-llvm: reset-llvm: $(MAKE) -C mlir reset-llvm -clean-mhlo: - $(MAKE) -C mlir clean-mhlo - clean-enzyme: $(MAKE) -C mlir clean-enzyme From e99b84299e8448a1cfb46f63f6811e38f3fd1ed0 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 1 Aug 2025 10:16:03 -0400 Subject: [PATCH 38/63] build stablehlo as standalone instead of embedded (cmake) --- mlir/CMakeLists.txt | 47 +++++++++++++++++-- mlir/cmake/modules/CMakeLists.txt | 2 +- .../Catalyst/Transforms/CMakeLists.txt | 4 +- mlir/include/stablehlo/CMakeLists.txt | 4 +- mlir/lib/Catalyst/Transforms/CMakeLists.txt | 1 + mlir/lib/Driver/CMakeLists.txt | 4 +- mlir/lib/stablehlo/CMakeLists.txt | 1 + mlir/tools/catalyst-cli/CMakeLists.txt | 2 +- mlir/tools/quantum-lsp-server/CMakeLists.txt | 2 +- mlir/tools/quantum-opt/CMakeLists.txt | 3 +- 10 files changed, 57 insertions(+), 13 deletions(-) diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index 6b38bfd502..8a92dd9251 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -42,10 +42,49 @@ list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/modules") -unset(LLVM_USE_LINKER) -set(STABLEHLO_BUILD_EMBEDDED ON) -set(CMAKE_COMPILE_WARNING_AS_ERROR OFF) -add_subdirectory(stablehlo) + +# Discover stablehlo libraries and bundle them into a cmake target for catalyst to link to +# This is because stablehlo does not have a Config.cmake +# so we cannot use find_package(stablehlo) and have to do this manually +set(STABLEHLO_LIBS + ChloCAPI + ChloOps + StablehloAssemblyFormat + StablehloBase + StablehloBroadcastUtils + StablehloCAPI + StablehloLinalgTransforms + StablehloOps + StablehloOptimizationPasses + StablehloPasses + StablehloPassUtils + StablehloRegister + StablehloTypeConversion + StablehloTypeInference + Version + VhloCAPI + VhloOps + VhloTypes +) + +set(STABLEHLO_LIBS_DIR ${PROJECT_SOURCE_DIR}/stablehlo/build/lib) +foreach(STABLEHLO_LIB IN LISTS STABLEHLO_LIBS) + add_library(${STABLEHLO_LIB} STATIC IMPORTED GLOBAL) + set_property(TARGET ${STABLEHLO_LIB} PROPERTY + IMPORTED_LOCATION "${STABLEHLO_LIBS_DIR}/lib${STABLEHLO_LIB}.a" + ) +endforeach() + +add_library(ExternalStablehloLib INTERFACE) + +foreach(STABLEHLO_LIB IN LISTS STABLEHLO_LIBS) + target_link_libraries(ExternalStablehloLib INTERFACE ${STABLEHLO_LIB}) +endforeach() + +target_include_directories(ExternalStablehloLib INTERFACE + ${PROJECT_SOURCE_DIR}/stablehlo + ${PROJECT_SOURCE_DIR}/stablehlo/build +) # Policy CMP0175 was introduced in CMake 4.31 and raises warnings in the upstream CMake modules. # Policy CMP0177 was introduced in CMake 4.31 and raises warnings in the upstream CMake modules. diff --git a/mlir/cmake/modules/CMakeLists.txt b/mlir/cmake/modules/CMakeLists.txt index 835aea8ced..b262fb0b01 100644 --- a/mlir/cmake/modules/CMakeLists.txt +++ b/mlir/cmake/modules/CMakeLists.txt @@ -28,7 +28,7 @@ set(llvm_cmake_builddir "${LLVM_BINARY_DIR}/${LLVM_INSTALL_PACKAGE_DIR}") get_property(MLIR_EXPORTS GLOBAL PROPERTY MLIR_EXPORTS) set(TARGETS_TO_REMOVE nlohmann_json tomlplusplus_tomlplusplus ion-transforms CatalystCompilerDriver QECUtils QuantumCAPI qec-transforms) list(REMOVE_ITEM MLIR_EXPORTS ${TARGETS_TO_REMOVE}) -export(TARGETS ${MLIR_EXPORTS} FILE ${catalyst_cmake_builddir}/CatalystTargets.cmake) +export(TARGETS ${MLIR_EXPORTS} ExternalStablehloLib FILE ${catalyst_cmake_builddir}/CatalystTargets.cmake) # Generate MlirConfig.cmake for the build tree. set(CATALYST_CONFIG_CMAKE_DIR "${catalyst_cmake_builddir}") diff --git a/mlir/include/Catalyst/Transforms/CMakeLists.txt b/mlir/include/Catalyst/Transforms/CMakeLists.txt index 52802b77fc..68deb5bc72 100644 --- a/mlir/include/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/include/Catalyst/Transforms/CMakeLists.txt @@ -1,4 +1,6 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name catalyst) -add_public_tablegen_target(MLIRCatalystPassIncGen) +add_public_tablegen_target(MLIRCatalystPassIncGen + DEPENDS ExternalStablehloLib +) add_mlir_doc(Passes CatalystPasses ./ -gen-pass-doc) diff --git a/mlir/include/stablehlo/CMakeLists.txt b/mlir/include/stablehlo/CMakeLists.txt index d643202931..d7c71deda7 100644 --- a/mlir/include/stablehlo/CMakeLists.txt +++ b/mlir/include/stablehlo/CMakeLists.txt @@ -1,6 +1,8 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name stablehlo) -add_public_tablegen_target(STABLEHLOCatalystPassIncGen) +add_public_tablegen_target(STABLEHLOCatalystPassIncGen + DEPENDS ExternalStablehloLib +) # The following is modified from the # tensorflow/mlir-hlo diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index 437a00b70f..a5342e3f0a 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -45,3 +45,4 @@ target_include_directories(${LIBRARY_NAME} PUBLIC . ${PROJECT_SOURCE_DIR}/include ${CMAKE_BINARY_DIR}/include) +target_link_libraries(${LIBRARY_NAME} PRIVATE ExternalStablehloLib) diff --git a/mlir/lib/Driver/CMakeLists.txt b/mlir/lib/Driver/CMakeLists.txt index 806b585b2b..43ec38a7c1 100644 --- a/mlir/lib/Driver/CMakeLists.txt +++ b/mlir/lib/Driver/CMakeLists.txt @@ -42,8 +42,6 @@ set(LIBS ion-transforms MLIRCatalystTest ${ENZYME_LIB} - StablehloRegister - StablehloCAPI ) add_mlir_library(CatalystCompilerDriver @@ -54,3 +52,5 @@ add_mlir_library(CatalystCompilerDriver LINK_LIBS PRIVATE ${LIBS} ) + +target_link_libraries(CatalystCompilerDriver PRIVATE ExternalStablehloLib) diff --git a/mlir/lib/stablehlo/CMakeLists.txt b/mlir/lib/stablehlo/CMakeLists.txt index 8018f749ed..4e035f0114 100644 --- a/mlir/lib/stablehlo/CMakeLists.txt +++ b/mlir/lib/stablehlo/CMakeLists.txt @@ -25,3 +25,4 @@ target_include_directories(${LIBRARY_NAME} PUBLIC . ${PROJECT_SOURCE_DIR}/include ${CMAKE_BINARY_DIR}/include) +target_link_libraries(${LIBRARY_NAME} PRIVATE ExternalStablehloLib) diff --git a/mlir/tools/catalyst-cli/CMakeLists.txt b/mlir/tools/catalyst-cli/CMakeLists.txt index 46d667cba1..93b204347d 100644 --- a/mlir/tools/catalyst-cli/CMakeLists.txt +++ b/mlir/tools/catalyst-cli/CMakeLists.txt @@ -39,7 +39,6 @@ set(LIBS mitigation-transforms MLIRIon ion-transforms - StablehloRegister MLIRCatalystTest ${ENZYME_LIB} CatalystCompilerDriver @@ -47,6 +46,7 @@ set(LIBS add_mlir_tool(catalyst-cli catalyst-cli.cpp SUPPORT_PLUGINS) target_link_libraries(catalyst-cli PRIVATE ${LIBS}) +target_link_libraries(catalyst-cli PRIVATE ExternalStablehloLib) llvm_update_compile_flags(catalyst-cli) mlir_check_all_link_libraries(catalyst-cli) export_executable_symbols_for_plugins(catalyst-cli) diff --git a/mlir/tools/quantum-lsp-server/CMakeLists.txt b/mlir/tools/quantum-lsp-server/CMakeLists.txt index d0438c26c1..8809649d62 100644 --- a/mlir/tools/quantum-lsp-server/CMakeLists.txt +++ b/mlir/tools/quantum-lsp-server/CMakeLists.txt @@ -11,10 +11,10 @@ set(LIBS MLIRMBQC MLIRMitigation MLIRIon - StablehloRegister ) add_llvm_executable(quantum-lsp-server quantum-lsp-server.cpp) target_link_libraries(quantum-lsp-server PRIVATE ${LIBS}) +target_link_libraries(quantum-lsp-server PRIVATE ExternalStablehloLib) llvm_update_compile_flags(quantum-lsp-server) mlir_check_all_link_libraries(quantum-lsp-server) diff --git a/mlir/tools/quantum-opt/CMakeLists.txt b/mlir/tools/quantum-opt/CMakeLists.txt index 2242bd61fa..922d4b71da 100644 --- a/mlir/tools/quantum-opt/CMakeLists.txt +++ b/mlir/tools/quantum-opt/CMakeLists.txt @@ -21,8 +21,6 @@ set(LIBS mitigation-transforms MLIRIon ion-transforms - StablehloRegister - StablehloCAPI MLIRCatalystTest MLIRCatalystUtils MLIRTestDialect @@ -30,6 +28,7 @@ set(LIBS add_mlir_tool(quantum-opt quantum-opt.cpp DEPENDS ${LIBS} SUPPORT_PLUGINS) target_link_libraries(quantum-opt PRIVATE ${LIBS}) +target_link_libraries(quantum-opt PRIVATE ExternalStablehloLib) llvm_update_compile_flags(quantum-opt) mlir_check_all_link_libraries(quantum-opt) export_executable_symbols_for_plugins(quantum-opt) From accd284e4e05418f612ff95dffc5f855666d19b4 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 1 Aug 2025 10:43:08 -0400 Subject: [PATCH 39/63] makefile --- Makefile | 11 +++++++-- mlir/CMakeLists.txt | 5 ++-- mlir/Makefile | 35 ++++++++++++++++++++++++--- mlir/test/frontend/lit.site.cfg.py.in | 1 - 4 files changed, 43 insertions(+), 9 deletions(-) diff --git a/Makefile b/Makefile index 8103554f3e..efcba75aee 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,7 @@ BLACKVERSIONMINOR := $(if $(BLACKVERSIONMINOR),$(BLACKVERSIONMINOR),0) MK_ABSPATH := $(abspath $(lastword $(MAKEFILE_LIST))) MK_DIR := $(dir $(MK_ABSPATH)) LLVM_BUILD_DIR ?= $(MK_DIR)/mlir/llvm-project/build +STABLEHLO_BUILD_DIR ?= $(MK_DIR)/mlir/stablehlo/build DIALECTS_SRC_DIR ?= $(MK_DIR)/mlir DIALECTS_BUILD_DIR ?= $(MK_DIR)/mlir/build RT_BUILD_DIR ?= $(MK_DIR)/runtime/build @@ -118,13 +119,16 @@ frontend: $(PYTHON) -m pip install -e . --extra-index-url https://test.pypi.org/simple $(PIP_VERBOSE_FLAG) rm -r frontend/pennylane_catalyst.egg-info -.PHONY: mlir llvm enzyme dialects runtime oqc +.PHONY: mlir llvm stablehlo enzyme dialects runtime oqc mlir: $(MAKE) -C mlir all llvm: $(MAKE) -C mlir llvm +stablehlo: + $(MAKE) -C mlir stablehlo + enzyme: $(MAKE) -C mlir enzyme @@ -265,7 +269,7 @@ clean: clean-all: clean clean-mlir clean-runtime clean-oqc clean-catalyst: clean clean-dialects clean-runtime clean-oqc -.PHONY: clean-mlir clean-dialects clean-plugin clean-llvm clean-enzyme +.PHONY: clean-mlir clean-dialects clean-plugin clean-llvm clean-stablehlo clean-enzyme clean-mlir: $(MAKE) -C mlir clean @@ -281,6 +285,9 @@ clean-llvm: reset-llvm: $(MAKE) -C mlir reset-llvm +clean-stablehlo: + $(MAKE) -C mlir clean-stablehlo + clean-enzyme: $(MAKE) -C mlir clean-enzyme diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index 8a92dd9251..6e1a59a296 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -67,7 +67,6 @@ set(STABLEHLO_LIBS VhloTypes ) -set(STABLEHLO_LIBS_DIR ${PROJECT_SOURCE_DIR}/stablehlo/build/lib) foreach(STABLEHLO_LIB IN LISTS STABLEHLO_LIBS) add_library(${STABLEHLO_LIB} STATIC IMPORTED GLOBAL) set_property(TARGET ${STABLEHLO_LIB} PROPERTY @@ -82,8 +81,8 @@ foreach(STABLEHLO_LIB IN LISTS STABLEHLO_LIBS) endforeach() target_include_directories(ExternalStablehloLib INTERFACE - ${PROJECT_SOURCE_DIR}/stablehlo - ${PROJECT_SOURCE_DIR}/stablehlo/build + ${STABLEHLO_DIR} + ${STABLEHLO_DIR}/build # for the generated .inc files ) # Policy CMP0175 was introduced in CMake 4.31 and raises warnings in the upstream CMake modules. diff --git a/mlir/Makefile b/mlir/Makefile index 09eb2322d2..8a64cf1d5c 100644 --- a/mlir/Makefile +++ b/mlir/Makefile @@ -7,6 +7,7 @@ MK_ABSPATH := $(abspath $(lastword $(MAKEFILE_LIST))) MK_DIR := $(dir $(MK_ABSPATH)) DIALECTS_BUILD_DIR ?= $(MK_DIR)/build LLVM_BUILD_DIR ?= $(MK_DIR)/llvm-project/build +STABLEHLO_BUILD_DIR ?= $(MK_DIR)/stablehlo/build ENZYME_BUILD_DIR ?= $(MK_DIR)/Enzyme/build RT_BUILD_DIR ?= $(MK_DIR)/../runtime/build ENABLE_ASAN ?= OFF @@ -42,6 +43,7 @@ help: @echo "Please use \`make ' where is one of" @echo " all to build MLIR, MLIR-HLO and custom Catalyst dialects" @echo " llvm to build MLIR enabling Python bindings" + @echo " stablehlo to build stablehlo" @echo " enzyme to build Enzyme" @echo " dialects to build custom Catalyst MLIR dialects" @echo " test to run the Catalyst MLIR dialects test suite" @@ -50,7 +52,7 @@ help: @echo " format [version=?] to apply C++ formatter; use with 'version={version}' to run clang-format-{version} instead of clang-format" .PHONY: all -all: llvm enzyme dialects plugin +all: llvm stablehlo enzyme dialects plugin .PHONY: llvm llvm: @@ -90,6 +92,26 @@ llvm: # test to reduce unnecessary dependencies. LIT_FILTER_OUT="Bytecode|tosa-to-tensor|execution_engine" cmake --build $(LLVM_BUILD_DIR) --target $(LLVM_TARGETS) +.PHONY: stablehlo +stablehlo: + @echo "build stablehlo" + + cmake -G Ninja -S stablehlo -B $(STABLEHLO_BUILD_DIR) \ + -DSTABLEHLO_ENABLE_LLD=ON \ + -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \ + -DMLIR_DIR=$(LLVM_BUILD_DIR)/lib/cmake/mlir \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DSTABLEHLO_ENABLE_BINDINGS_PYTHON=OFF \ + -DSTABLEHLO_ENABLE_SPLIT_DWARF=ON \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ + -DCMAKE_C_COMPILER=$(C_COMPILER) \ + -DCMAKE_CXX_COMPILER=$(CXX_COMPILER) \ + -DCMAKE_C_COMPILER_LAUNCHER=$(COMPILER_LAUNCHER) \ + -DCMAKE_CXX_COMPILER_LAUNCHER=$(COMPILER_LAUNCHER) \ + -DCMAKE_EXE_LINKER_FLAGS=$(USE_SANITIZER_FLAGS) \ + -DCMAKE_CXX_VISIBILITY_PRESET=$(SYMBOL_VISIBILITY) + + cmake --build $(STABLEHLO_BUILD_DIR) .PHONY: enzyme enzyme: TARGET_FILE := $(MK_DIR)/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -149,6 +171,8 @@ dialects: -DEnzyme_DIR=$(ENZYME_BUILD_DIR) \ -DENZYME_SRC_DIR=$(MK_DIR)/Enzyme \ -DMLIR_DIR=$(LLVM_BUILD_DIR)/lib/cmake/mlir \ + -DSTABLEHLO_DIR=$(MK_DIR)/stablehlo \ + -DSTABLEHLO_LIBS_DIR=$(STABLEHLO_BUILD_DIR)/lib \ -DRUNTIME_LIB_DIR=$(RT_BUILD_DIR)/lib \ -DMLIR_LIB_DIR=$(LLVM_BUILD_DIR)/lib \ -DCMAKE_C_COMPILER=$(C_COMPILER) \ @@ -168,8 +192,8 @@ test: @echo "test the Catalyst MLIR dialects test suite" cmake --build $(DIALECTS_BUILD_DIR) --target check-dialects -.PHONY: clean clean-dialects clean-enzyme clean-plugin -clean: clean-dialects clean-llvm clean-enzyme clean-plugin +.PHONY: clean clean-dialects clean-enzyme clean-stablehlo clean-plugin +clean: clean-dialects clean-llvm clean-stablehlo clean-enzyme clean-plugin clean-dialects: @echo "clean catalyst dialect build files" @@ -180,6 +204,11 @@ clean-llvm: rm -rf $(LLVM_BUILD_DIR) cd llvm-project; git clean -fd; git checkout . +clean-stablehlo: + @echo "clean Stablehlo dialect build files" + rm -rf $(STABLEHLO_BUILD_DIR) + cd stablehlo; git clean -fd; git checkout . + reset-llvm: @echo "reset llvm git state to the commit tracked in .dep-versions without deleting llvm builds" cd llvm-project; git clean -fd; git checkout . diff --git a/mlir/test/frontend/lit.site.cfg.py.in b/mlir/test/frontend/lit.site.cfg.py.in index 2a6816dbfd..e888797f16 100644 --- a/mlir/test/frontend/lit.site.cfg.py.in +++ b/mlir/test/frontend/lit.site.cfg.py.in @@ -5,7 +5,6 @@ config.python_executable = "@Python3_EXECUTABLE@" config.frontend_test_dir = "@CMAKE_BINARY_DIR@" + "/test/frontend" config.quantum_bin_dir = "@CMAKE_BINARY_DIR@" + "/bin" config.mlir_bindings_dir = "@CMAKE_BINARY_DIR@" + "/python_packages/quantum" -config.mhlo_bin_dir = "@MHLO_BINARY_DIR@" config.lrt_lib_dir = "@RUNTIME_LIB_DIR@" config.mlir_lib_dir = "@MLIR_LIB_DIR@" From 7af47c2145283d40d4fa253f0c6634e5db0d8812 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 1 Aug 2025 10:55:53 -0400 Subject: [PATCH 40/63] regular CI --- .github/workflows/check-catalyst.yaml | 101 ++++++++++++++++++++++++-- 1 file changed, 93 insertions(+), 8 deletions(-) diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index fa4311b03b..7511e8c10b 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -172,6 +172,80 @@ jobs: COMPILER_LAUNCHER="" \ make llvm + stablehlo: + name: Stablehlo Dialect Build + needs: [constants, llvm, determine_runner] + runs-on: ${{ needs.determine_runner.outputs.runner_group }} + strategy: + matrix: + compiler: ${{ fromJson(needs.constants.outputs.compilers) }} + + steps: + - name: Checkout Catalyst repo + uses: actions/checkout@v4 + + - name: Set up Python # Ensure the "primary" python version is used + uses: actions/setup-python@v5 + with: + python-version: ${{ needs.constants.outputs.primary_python_version }} + + - name: Cache Stablehlo Source + id: cache-stablehlo-source + uses: actions/cache@v4 + with: + path: mlir/mlir-hlo + key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-default-source + enableCrossOsArchive: true + + - name: Clone Stablehlo Submodule + if: steps.cache-stablehlo-source.outputs.cache-hit != 'true' + uses: actions/checkout@v4 + with: + repository: openxla/stablehlo + ref: ${{ needs.constants.outputs.stablehlo_version }} + path: mlir/stablehlo + + - name: Cache Stablehlo Build + id: cache-stablehlo + uses: actions/cache@v4 + with: + path: stablehlo-build + key: ${{ runner.os }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-ci-build-${{ matrix.compiler }}-0 + + - name: Get Cached LLVM Source + id: cache-llvm-source + if: steps.cache-stablehlo.outputs.cache-hit != 'true' + uses: actions/cache@v4 + with: + path: mlir/llvm-project + key: llvm-${{ needs.constants.outputs.llvm_version }}-default-source + enableCrossOsArchive: true + fail-on-cache-miss: true + + - name: Get Cached LLVM Build + id: cache-llvm-build + if: steps.cache-stablehlo.outputs.cache-hit != 'true' + uses: actions/cache@v4 + with: + path: llvm-build + key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-ci-build-${{ matrix.compiler }} + fail-on-cache-miss: true + + - name: Install Deps + if: steps.cache-stablehlo.outputs.cache-hit != 'true' + run: | + sudo apt-get update + sudo apt-get install -y cmake ninja-build clang lld + - name: Build Stablehlo Dialect + if: steps.cache-stablehlo.outputs.cache-hit != 'true' + run: | + C_COMPILER=$(which ${{ needs.constants.outputs[format('c_compiler.{0}', matrix.compiler)] }}) \ + CXX_COMPILER=$(which ${{ needs.constants.outputs[format('cxx_compiler.{0}', matrix.compiler)] }}) \ + LLVM_BUILD_DIR="$(pwd)/llvm-build" \ + STABLEHLO_BUILD_DIR="$(pwd)/stablehlo-build" \ + COMPILER_LAUNCHER="" \ + make stablehlo + enzyme: name: Enzyme Build needs: [constants, llvm, determine_runner] @@ -249,7 +323,7 @@ jobs: quantum: name: Quantum Dialects Build - needs: [constants, llvm, enzyme, determine_runner] + needs: [constants, llvm, stablehlo, enzyme, determine_runner] runs-on: ${{ needs.determine_runner.outputs.runner_group }} strategy: matrix: @@ -288,6 +362,23 @@ jobs: key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-ci-build-${{ matrix.compiler }} fail-on-cache-miss: true + - name: Get Cached Stablehlo Source + id: cache-stablehlo-source + uses: actions/cache/restore@v4 + with: + path: mlir/mlir-hlo + key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-default-source + enableCrossOsArchive: true + fail-on-cache-miss: true + + - name: Get Cached Stablehlo Build + id: cache-stablehlo + uses: actions/cache/restore@v4 + with: + path: stablehlo-build + key: ${{ runner.os }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-ci-build-${{ matrix.compiler }}-0 + fail-on-cache-miss: true + - name: Get Cached Enzyme Source id: cache-enzyme-source uses: actions/cache/restore@v4 @@ -314,19 +405,13 @@ jobs: key: ${{ runner.os }}-ccache-${{ github.run_id }} restore-keys: ${{ runner.os }}-ccache- - - name: Clone Stablehlo Submodule - uses: actions/checkout@v4 - with: - repository: openxla/stablehlo - ref: ${{ needs.constants.outputs.stablehlo_version }} - path: mlir/stablehlo - - name: Build MLIR Dialects run: | CCACHE_DIR="$(pwd)/.ccache" \ C_COMPILER=$(which ${{ needs.constants.outputs[format('c_compiler.{0}', matrix.compiler)] }}) \ CXX_COMPILER=$(which ${{ needs.constants.outputs[format('cxx_compiler.{0}', matrix.compiler)] }}) \ LLVM_BUILD_DIR="$(pwd)/llvm-build" \ + STABLEHLO_BUILD_DIR="$(pwd)/stablehlo-build" \ ENZYME_BUILD_DIR="$(pwd)/enzyme-build" \ DIALECTS_BUILD_DIR="$(pwd)/quantum-build" \ make dialects From ae17b0fa2bdd3815b0d5b1b10470aabb462bf5fb Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 1 Aug 2025 11:24:58 -0400 Subject: [PATCH 41/63] misc --- .github/workflows/check-catalyst.yaml | 4 ++-- mlir/Makefile | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index 7511e8c10b..6ef0252f4f 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -193,7 +193,7 @@ jobs: id: cache-stablehlo-source uses: actions/cache@v4 with: - path: mlir/mlir-hlo + path: mlir/stablehlo key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-default-source enableCrossOsArchive: true @@ -366,7 +366,7 @@ jobs: id: cache-stablehlo-source uses: actions/cache/restore@v4 with: - path: mlir/mlir-hlo + path: mlir/stablehlo key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-default-source enableCrossOsArchive: true fail-on-cache-miss: true diff --git a/mlir/Makefile b/mlir/Makefile index 8a64cf1d5c..8aa6f8c91c 100644 --- a/mlir/Makefile +++ b/mlir/Makefile @@ -41,7 +41,7 @@ LLVM_TARGETS ?= check-mlir llvm-symbolizer .PHONY: help help: @echo "Please use \`make ' where is one of" - @echo " all to build MLIR, MLIR-HLO and custom Catalyst dialects" + @echo " all to build MLIR, Stablehlo and custom Catalyst dialects" @echo " llvm to build MLIR enabling Python bindings" @echo " stablehlo to build stablehlo" @echo " enzyme to build Enzyme" From c0858930f973dc98b14dd8cd021176acd28c90e8 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 1 Aug 2025 11:25:11 -0400 Subject: [PATCH 42/63] wheel CI --- .../workflows/build-wheel-linux-arm64.yaml | 76 +++++++++++++++++-- .../workflows/build-wheel-linux-x86_64.yaml | 76 +++++++++++++++++-- .../workflows/build-wheel-macos-arm64.yaml | 75 ++++++++++++++++-- 3 files changed, 206 insertions(+), 21 deletions(-) diff --git a/.github/workflows/build-wheel-linux-arm64.yaml b/.github/workflows/build-wheel-linux-arm64.yaml index 7e1e84b797..fd90befbcf 100644 --- a/.github/workflows/build-wheel-linux-arm64.yaml +++ b/.github/workflows/build-wheel-linux-arm64.yaml @@ -84,6 +84,14 @@ jobs: key: llvm-${{ needs.constants.outputs.llvm_version }}-container-source enableCrossOsArchive: True + - name: Cache Stablehlo Source + id: cache-stablehlo-source + uses: actions/cache@v4 + with: + path: ${{ github.workspace }}/mlir/stablehlo + key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-container-source + enableCrossOsArchive: True + - name: Cache Enzyme Source id: cache-enzyme-source uses: actions/cache@v4 @@ -106,6 +114,14 @@ jobs: cd $GITHUB_WORKSPACE/mlir/llvm-project git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch + - name: Clone Stablehlo Submodule + if: steps.cache-stablehlo-source.outputs.cache-hit != 'true' + uses: actions/checkout@v4 + with: + repository: openxla/stablehlo + ref: ${{ needs.constants.outputs.stablehlo_version }} + path: mlir/stablehlo + - name: Clone Enzyme Submodule if: steps.cache-enzyme-source.outputs.cache-hit != 'true' uses: actions/checkout@v4 @@ -128,6 +144,14 @@ jobs: path: ${{ github.workspace }}/llvm-build key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{ matrix.python_version }}-wheel-build + - name: Check Stablehlo Build Cache + id: cache-stablehlo-build + uses: actions/cache/restore@v4 + with: + path: ${{ github.workspace }}/stablehlo-build + key: ${{ matrix.container_img }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build + lookup-only: True + - name: Check Enzyme Build Cache id: cache-enzyme-build uses: actions/cache/restore@v4 @@ -139,6 +163,7 @@ jobs: - name: Install dependencies if: | steps.cache-llvm-build.outputs.cache-hit != 'true' || + steps.cache-stablehlo-build.outputs.cache-hit != 'true' || steps.cache-enzyme-build.outputs.cache-hit != 'true' run: | cat /etc/dnf.conf | sed "s/\[main\]/\[main\]\ntimeout=5/g" > /etc/dnf.conf @@ -175,6 +200,30 @@ jobs: path: ${{ github.workspace }}/llvm-build key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{matrix.python_version}}-wheel-build + - name: Build Stablehlo Dialect + if: steps.cache-stablehlo-build.outputs.cache-hit != 'true' + run: | + export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH + cmake -S mlir/stablehlo -B $GITHUB_WORKSPACE/stablehlo-build -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ + -DSTABLEHLO_ENABLE_LLD=ON \ + -DSTABLEHLO_ENABLE_BINDINGS_PYTHON=OFF \ + -DSTABLEHLO_ENABLE_SPLIT_DWARF=ON \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ + -DCMAKE_CXX_VISIBILITY_PRESET=default + + cmake --build $GITHUB_WORKSPACE/stablehlo-build + + - name: Save Stablehlo Build + id: save-stablehlo-build + if: steps.cache-stablehlo-build.outputs.cache-hit != 'true' + uses: actions/cache/save@v4 + with: + path: ${{ github.workspace }}/stablehlo-build + key: ${{ matrix.container_img }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build + - name: Build Enzyme if: steps.cache-enzyme-build.outputs.cache-hit != 'true' run: | @@ -247,6 +296,23 @@ jobs: key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-3.11-wheel-build fail-on-cache-miss: True + - name: Get Cached Stablehlo Source + id: cache-stablehlo-source + uses: actions/cache/restore@v4 + with: + path: ${{ github.workspace }}/mlir/stablehlo + key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-container-source + enableCrossOsArchive: True + fail-on-cache-miss: True + + - name: Get Cached Stablehlo Build + id: cache-stablehlo-build + uses: actions/cache/restore@v4 + with: + path: ${{ github.workspace }}/stablehlo-build + key: ${{ matrix.container_img }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build + fail-on-cache-miss: True + - name: Get Cached Enzyme Source id: cache-enzyme-source uses: actions/cache/restore@v4 @@ -285,13 +351,6 @@ jobs: PYTHON=$(which python${{ matrix.python_version }}) \ make oqc - - name: Clone Stablehlo Submodule - uses: actions/checkout@v4 - with: - repository: openxla/stablehlo - ref: ${{ needs.constants.outputs.stablehlo_version }} - path: mlir/stablehlo - # Build Quantum and Gradient Dialects - name: Build MLIR Dialects run: | @@ -304,6 +363,8 @@ jobs: -DPython3_EXECUTABLE=$(which python${{ matrix.python_version }}) \ -DPython3_NumPy_INCLUDE_DIRS=$(python${{ matrix.python_version }} -c "import numpy as np; print(np.get_include())") \ -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ + -DSTABLEHLO_DIR="$GITHUB_WORKSPACE/mlir/stablehlo" \ + -DSTABLEHLO_LIBS_DIR="$GITHUB_WORKSPACE/stablehlo-build/lib" \ -DEnzyme_DIR="$GITHUB_WORKSPACE/enzyme-build" \ -DENZYME_SRC_DIR="$GITHUB_WORKSPACE/mlir/Enzyme" \ -DLLVM_ENABLE_ZLIB=FORCE_ON \ @@ -326,6 +387,7 @@ jobs: run: | PYTHON=python${{ matrix.python_version }} \ LLVM_BUILD_DIR="$GITHUB_WORKSPACE/llvm-build" \ + STABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ DIALECTS_BUILD_DIR="$GITHUB_WORKSPACE/quantum-build" \ RT_BUILD_DIR="$GITHUB_WORKSPACE/runtime-build" \ OQC_BUILD_DIR="$GITHUB_WORKSPACE/oqc-build" \ diff --git a/.github/workflows/build-wheel-linux-x86_64.yaml b/.github/workflows/build-wheel-linux-x86_64.yaml index 957bca3d79..738d3fc31a 100644 --- a/.github/workflows/build-wheel-linux-x86_64.yaml +++ b/.github/workflows/build-wheel-linux-x86_64.yaml @@ -103,6 +103,14 @@ jobs: key: llvm-${{ needs.constants.outputs.llvm_version }}-container-source enableCrossOsArchive: True + - name: Cache Stablehlo Source + id: cache-stablehlo-source + uses: actions/cache@v4 + with: + path: ${{ github.workspace }}/mlir/stablehlo + key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-container-source + enableCrossOsArchive: True + - name: Cache Enzyme Source id: cache-enzyme-source uses: actions/cache@v4 @@ -125,6 +133,14 @@ jobs: cd $GITHUB_WORKSPACE/mlir/llvm-project git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch + - name: Clone Stablehlo Submodule + if: steps.cache-stablehlo-source.outputs.cache-hit != 'true' + uses: actions/checkout@v4 + with: + repository: openxla/stablehlo + ref: ${{ needs.constants.outputs.stablehlo_version }} + path: mlir/stablehlo + - name: Clone Enzyme Submodule if: steps.cache-enzyme-source.outputs.cache-hit != 'true' uses: actions/checkout@v4 @@ -147,6 +163,14 @@ jobs: path: ${{ github.workspace }}/llvm-build key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{ matrix.python_version }}-wheel-build + - name: Check Stablehlo Build Cache + id: cache-stablehlo-build + uses: actions/cache/restore@v4 + with: + path: ${{ github.workspace }}/stablehlo-build + key: ${{ matrix.container_img }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build + lookup-only: True + - name: Check Enzyme Build Cache id: cache-enzyme-build uses: actions/cache/restore@v4 @@ -158,6 +182,7 @@ jobs: - name: Install dependencies (AlmaLinux) if: | steps.cache-llvm-build.outputs.cache-hit != 'true' || + steps.cache-stablehlo-build.outputs.cache-hit != 'true' || steps.cache-enzyme-build.outputs.cache-hit != 'true' run: | # Reduce wait time for repos not responding @@ -197,6 +222,30 @@ jobs: path: ${{ github.workspace }}/llvm-build key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{matrix.python_version}}-wheel-build + - name: Build Stablehlo Dialect + if: steps.cache-stablehlo-build.outputs.cache-hit != 'true' + run: | + export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH + cmake -S mlir/stablehlo -B $GITHUB_WORKSPACE/stablehlo-build -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ + -DSTABLEHLO_ENABLE_LLD=ON \ + -DSTABLEHLO_ENABLE_BINDINGS_PYTHON=OFF \ + -DSTABLEHLO_ENABLE_SPLIT_DWARF=ON \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ + -DCMAKE_CXX_VISIBILITY_PRESET=default + + cmake --build $GITHUB_WORKSPACE/stablehlo-build + + - name: Save Stablehlo Build + id: save-stablehlo-build + if: steps.cache-stablehlo-build.outputs.cache-hit != 'true' + uses: actions/cache/save@v4 + with: + path: ${{ github.workspace }}/stablehlo-build + key: ${{ matrix.container_img }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build + - name: Build Enzyme if: steps.cache-enzyme-build.outputs.cache-hit != 'true' run: | @@ -269,6 +318,23 @@ jobs: key: ${{ matrix.container_img }}-llvm-${{ needs.constants.outputs.llvm_version }}-3.11-wheel-build fail-on-cache-miss: True + - name: Get Cached Stablehlo Source + id: cache-stablehlo-source + uses: actions/cache/restore@v4 + with: + path: ${{ github.workspace }}/mlir/stablehlo + key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-container-source + enableCrossOsArchive: True + fail-on-cache-miss: True + + - name: Get Cached Stablehlo Build + id: cache-stablehlo-build + uses: actions/cache/restore@v4 + with: + path: ${{ github.workspace }}/stablehlo-build + key: ${{ matrix.container_img }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build + fail-on-cache-miss: True + - name: Get Cached Enzyme Source id: cache-enzyme-source uses: actions/cache/restore@v4 @@ -309,13 +375,6 @@ jobs: PYTHON=$(which python${{ matrix.python_version }}) \ make oqc - - name: Clone Stablehlo Submodule - uses: actions/checkout@v4 - with: - repository: openxla/stablehlo - ref: ${{ needs.constants.outputs.stablehlo_version }} - path: mlir/stablehlo - # Build Quantum and Gradient Dialects - name: Build MLIR Dialects run: | @@ -328,6 +387,8 @@ jobs: -DPython3_EXECUTABLE=$(which python${{ matrix.python_version }}) \ -DPython3_NumPy_INCLUDE_DIRS=$(python${{ matrix.python_version }} -c "import numpy as np; print(np.get_include())") \ -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ + -DSTABLEHLO_DIR="$GITHUB_WORKSPACE/mlir/stablehlo" \ + -DSTABLEHLO_LIBS_DIR="$GITHUB_WORKSPACE/stablehlo-build/lib" \ -DEnzyme_DIR="$GITHUB_WORKSPACE/enzyme-build" \ -DENZYME_SRC_DIR="$GITHUB_WORKSPACE/mlir/Enzyme" \ -DLLVM_ENABLE_ZLIB=FORCE_ON \ @@ -350,6 +411,7 @@ jobs: run: | PYTHON=python${{ matrix.python_version }} \ LLVM_BUILD_DIR="$GITHUB_WORKSPACE/llvm-build" \ + STABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ DIALECTS_BUILD_DIR="$GITHUB_WORKSPACE/quantum-build" \ RT_BUILD_DIR="$GITHUB_WORKSPACE/runtime-build" \ OQC_BUILD_DIR="$GITHUB_WORKSPACE/oqc-build" \ diff --git a/.github/workflows/build-wheel-macos-arm64.yaml b/.github/workflows/build-wheel-macos-arm64.yaml index 6031c5c09f..5d0adfc2c9 100644 --- a/.github/workflows/build-wheel-macos-arm64.yaml +++ b/.github/workflows/build-wheel-macos-arm64.yaml @@ -89,6 +89,14 @@ jobs: key: llvm-${{ needs.constants.outputs.llvm_version }}-default-source enableCrossOsArchive: True + - name: Cache Stablehlo Source + id: cache-stablehlo-source + uses: actions/cache@v4 + with: + path: ${{ github.workspace }}/mlir/stablehlo + key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-container-source + enableCrossOsArchive: True + - name: Cache Enzyme Source id: cache-enzyme-source uses: actions/cache@v4 @@ -111,6 +119,14 @@ jobs: cd $GITHUB_WORKSPACE/mlir/llvm-project git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch + - name: Clone Stablehlo Submodule + if: steps.cache-stablehlo-source.outputs.cache-hit != 'true' + uses: actions/checkout@v4 + with: + repository: openxla/stablehlo + ref: ${{ needs.constants.outputs.stablehlo_version }} + path: mlir/stablehlo + - name: Clone Enzyme Submodule if: steps.cache-enzyme-source.outputs.cache-hit != 'true' uses: actions/checkout@v4 @@ -133,6 +149,14 @@ jobs: path: ${{ github.workspace }}/llvm-build key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{ matrix.python_version }}-wheel-build + - name: Check Stablehlo Build Cache + id: cache-stablehlo-build + uses: actions/cache/restore@v4 + with: + path: ${{ github.workspace }}/stablehlo-build + key: ${{ matrix.container_img }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build + lookup-only: True + - name: Check Enzyme Build Cache id: cache-enzyme-build uses: actions/cache/restore@v4 @@ -169,6 +193,30 @@ jobs: path: ${{ github.workspace }}/llvm-build key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{ matrix.python_version }}-wheel-build + - name: Build Stablehlo Dialect + if: steps.cache-stablehlo-build.outputs.cache-hit != 'true' + run: | + export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH + cmake -S mlir/stablehlo -B $GITHUB_WORKSPACE/stablehlo-build -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ + -DSTABLEHLO_ENABLE_LLD=ON \ + -DSTABLEHLO_ENABLE_BINDINGS_PYTHON=OFF \ + -DSTABLEHLO_ENABLE_SPLIT_DWARF=ON \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ + -DCMAKE_CXX_VISIBILITY_PRESET=default + + cmake --build $GITHUB_WORKSPACE/stablehlo-build + + - name: Save Stablehlo Build + id: save-stablehlo-build + if: steps.cache-stablehlo-build.outputs.cache-hit != 'true' + uses: actions/cache/save@v4 + with: + path: ${{ github.workspace }}/stablehlo-build + key: ${{ matrix.container_img }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build + - name: Build Enzyme if: steps.cache-enzyme-build.outputs.cache-hit != 'true' run: | @@ -235,6 +283,23 @@ jobs: key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ needs.constants.outputs.llvm_version }}-${{ needs.constants.outputs.primary_python_version }}-wheel-build fail-on-cache-miss: True + - name: Get Cached Stablehlo Source + id: cache-stablehlo-source + uses: actions/cache/restore@v4 + with: + path: ${{ github.workspace }}/mlir/stablehlo + key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-container-source + enableCrossOsArchive: True + fail-on-cache-miss: True + + - name: Get Cached Stablehlo Build + id: cache-stablehlo-build + uses: actions/cache/restore@v4 + with: + path: ${{ github.workspace }}/stablehlo-build + key: ${{ matrix.container_img }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build + fail-on-cache-miss: True + - name: Get Cached Enzyme Source id: cache-enzyme-source uses: actions/cache/restore@v4 @@ -287,13 +352,6 @@ jobs: PYTHON=$(which python${{ matrix.python_version }}) \ make oqc - - name: Clone Stablehlo Submodule - uses: actions/checkout@v4 - with: - repository: openxla/stablehlo - ref: ${{ needs.constants.outputs.stablehlo_version }} - path: mlir/stablehlo - # Build Quantum and Gradient Dialects - name: Build MLIR Dialects run: | @@ -304,6 +362,8 @@ jobs: -DPython3_EXECUTABLE=$(which python${{ matrix.python_version }}) \ -DPython3_NumPy_INCLUDE_DIRS=$(python${{ matrix.python_version }} -c "import numpy as np; print(np.get_include())") \ -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ + -DSTABLEHLO_DIR="$GITHUB_WORKSPACE/mlir/stablehlo" \ + -DSTABLEHLO_LIBS_DIR="$GITHUB_WORKSPACE/stablehlo-build/lib" \ -DEnzyme_DIR="$GITHUB_WORKSPACE/enzyme-build" \ -DENZYME_SRC_DIR="$GITHUB_WORKSPACE/mlir/Enzyme" \ -DLLVM_ENABLE_ZLIB=FORCE_ON \ @@ -325,6 +385,7 @@ jobs: run: | PYTHON=python${{ matrix.python_version }} \ LLVM_BUILD_DIR="$GITHUB_WORKSPACE/llvm-build" \ + STABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ DIALECTS_BUILD_DIR="$GITHUB_WORKSPACE/quantum-build" \ RT_BUILD_DIR="$GITHUB_WORKSPACE/runtime-build" \ OQC_BUILD_DIR="$GITHUB_WORKSPACE/oqc-build" \ From d15dbe7a6382c56583509612d98f71a115fe8376 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 1 Aug 2025 11:29:44 -0400 Subject: [PATCH 43/63] misc CI --- .github/workflows/check-jax-release.yaml | 4 ++++ .github/workflows/check-pl-compat.yaml | 14 ++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/.github/workflows/check-jax-release.yaml b/.github/workflows/check-jax-release.yaml index c6651d2f23..37221f4bb4 100644 --- a/.github/workflows/check-jax-release.yaml +++ b/.github/workflows/check-jax-release.yaml @@ -70,6 +70,10 @@ jobs: run: | make llvm + - name: Build MHLO + run: | + make stablehlo + - name: Build Enzyme run: | make enzyme diff --git a/.github/workflows/check-pl-compat.yaml b/.github/workflows/check-pl-compat.yaml index a6f9c7f842..29962195fd 100644 --- a/.github/workflows/check-pl-compat.yaml +++ b/.github/workflows/check-pl-compat.yaml @@ -73,6 +73,19 @@ jobs: path: llvm-build key: ${{ runner.os }}-llvm-${{ needs.constants.outputs.llvm_version }}-ci-build-gcc fail-on-cache-miss: True + - uses: actions/cache/restore@v4 + if: ${{ inputs.catalyst != 'stable' }} + with: + path: mlir/stablehlo + key: stablehlo-${{ needs.constants.outputs.stablehlo_version }}-default-source + enableCrossOsArchive: True + fail-on-cache-miss: True + - uses: actions/cache/restore@v4 + if: ${{ inputs.catalyst != 'stable' }} + with: + path: stablehlo-build + key: ${{ runner.os }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-ci-build-gcc + fail-on-cache-miss: True - uses: actions/cache/restore@v4 if: ${{ inputs.catalyst != 'stable' }} with: @@ -107,6 +120,7 @@ jobs: ENABLE_LLD=ON \ RT_BUILD_DIR="$(pwd)/runtime-build" \ LLVM_BUILD_DIR="$(pwd)/llvm-build" \ + STABLEHLO_BUILD_DIR="$(pwd)/stablehlo-build" \ ENZYME_BUILD_DIR="$(pwd)/enzyme-build" \ DIALECTS_BUILD_DIR="$(pwd)/quantum-build" \ ENABLE_OPENQASM=ON \ From ca830cc6b54c9d9ecc560c9579f33a5e297a487d Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 1 Aug 2025 11:56:26 -0400 Subject: [PATCH 44/63] small fix: unify cmake entry as STABLEHLO_BUILD_DIR instead of manually wiring inside cmake CI path is stablehlo-build but regular path is stablehlo/build --- .github/workflows/build-wheel-linux-arm64.yaml | 2 +- .github/workflows/build-wheel-linux-x86_64.yaml | 2 +- .github/workflows/build-wheel-macos-arm64.yaml | 2 +- mlir/CMakeLists.txt | 4 ++-- mlir/Makefile | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/build-wheel-linux-arm64.yaml b/.github/workflows/build-wheel-linux-arm64.yaml index fd90befbcf..214c4031b5 100644 --- a/.github/workflows/build-wheel-linux-arm64.yaml +++ b/.github/workflows/build-wheel-linux-arm64.yaml @@ -364,7 +364,7 @@ jobs: -DPython3_NumPy_INCLUDE_DIRS=$(python${{ matrix.python_version }} -c "import numpy as np; print(np.get_include())") \ -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ -DSTABLEHLO_DIR="$GITHUB_WORKSPACE/mlir/stablehlo" \ - -DSTABLEHLO_LIBS_DIR="$GITHUB_WORKSPACE/stablehlo-build/lib" \ + -DSTABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ -DEnzyme_DIR="$GITHUB_WORKSPACE/enzyme-build" \ -DENZYME_SRC_DIR="$GITHUB_WORKSPACE/mlir/Enzyme" \ -DLLVM_ENABLE_ZLIB=FORCE_ON \ diff --git a/.github/workflows/build-wheel-linux-x86_64.yaml b/.github/workflows/build-wheel-linux-x86_64.yaml index 738d3fc31a..2aee0a081b 100644 --- a/.github/workflows/build-wheel-linux-x86_64.yaml +++ b/.github/workflows/build-wheel-linux-x86_64.yaml @@ -388,7 +388,7 @@ jobs: -DPython3_NumPy_INCLUDE_DIRS=$(python${{ matrix.python_version }} -c "import numpy as np; print(np.get_include())") \ -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ -DSTABLEHLO_DIR="$GITHUB_WORKSPACE/mlir/stablehlo" \ - -DSTABLEHLO_LIBS_DIR="$GITHUB_WORKSPACE/stablehlo-build/lib" \ + -DSTABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ -DEnzyme_DIR="$GITHUB_WORKSPACE/enzyme-build" \ -DENZYME_SRC_DIR="$GITHUB_WORKSPACE/mlir/Enzyme" \ -DLLVM_ENABLE_ZLIB=FORCE_ON \ diff --git a/.github/workflows/build-wheel-macos-arm64.yaml b/.github/workflows/build-wheel-macos-arm64.yaml index 5d0adfc2c9..84fe77e3fd 100644 --- a/.github/workflows/build-wheel-macos-arm64.yaml +++ b/.github/workflows/build-wheel-macos-arm64.yaml @@ -363,7 +363,7 @@ jobs: -DPython3_NumPy_INCLUDE_DIRS=$(python${{ matrix.python_version }} -c "import numpy as np; print(np.get_include())") \ -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ -DSTABLEHLO_DIR="$GITHUB_WORKSPACE/mlir/stablehlo" \ - -DSTABLEHLO_LIBS_DIR="$GITHUB_WORKSPACE/stablehlo-build/lib" \ + -DSTABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ -DEnzyme_DIR="$GITHUB_WORKSPACE/enzyme-build" \ -DENZYME_SRC_DIR="$GITHUB_WORKSPACE/mlir/Enzyme" \ -DLLVM_ENABLE_ZLIB=FORCE_ON \ diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index 6e1a59a296..4db6fccc9d 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -70,7 +70,7 @@ set(STABLEHLO_LIBS foreach(STABLEHLO_LIB IN LISTS STABLEHLO_LIBS) add_library(${STABLEHLO_LIB} STATIC IMPORTED GLOBAL) set_property(TARGET ${STABLEHLO_LIB} PROPERTY - IMPORTED_LOCATION "${STABLEHLO_LIBS_DIR}/lib${STABLEHLO_LIB}.a" + IMPORTED_LOCATION "${STABLEHLO_BUILD_DIR}/lib/lib${STABLEHLO_LIB}.a" ) endforeach() @@ -82,7 +82,7 @@ endforeach() target_include_directories(ExternalStablehloLib INTERFACE ${STABLEHLO_DIR} - ${STABLEHLO_DIR}/build # for the generated .inc files + ${STABLEHLO_BUILD_DIR} # for the generated .inc files ) # Policy CMP0175 was introduced in CMake 4.31 and raises warnings in the upstream CMake modules. diff --git a/mlir/Makefile b/mlir/Makefile index 8aa6f8c91c..13b8ea7cb6 100644 --- a/mlir/Makefile +++ b/mlir/Makefile @@ -172,7 +172,7 @@ dialects: -DENZYME_SRC_DIR=$(MK_DIR)/Enzyme \ -DMLIR_DIR=$(LLVM_BUILD_DIR)/lib/cmake/mlir \ -DSTABLEHLO_DIR=$(MK_DIR)/stablehlo \ - -DSTABLEHLO_LIBS_DIR=$(STABLEHLO_BUILD_DIR)/lib \ + -DSTABLEHLO_BUILD_DIR=$(STABLEHLO_BUILD_DIR) \ -DRUNTIME_LIB_DIR=$(RT_BUILD_DIR)/lib \ -DMLIR_LIB_DIR=$(LLVM_BUILD_DIR)/lib \ -DCMAKE_C_COMPILER=$(C_COMPILER) \ From 72bf8dee14ff74cfd7595cd82e7c35e2bf4f212a Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 1 Aug 2025 12:27:01 -0400 Subject: [PATCH 45/63] mac does not have lld --- .github/workflows/build-wheel-macos-arm64.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build-wheel-macos-arm64.yaml b/.github/workflows/build-wheel-macos-arm64.yaml index 84fe77e3fd..f147684aa1 100644 --- a/.github/workflows/build-wheel-macos-arm64.yaml +++ b/.github/workflows/build-wheel-macos-arm64.yaml @@ -201,7 +201,7 @@ jobs: -DCMAKE_BUILD_TYPE=Release \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ - -DSTABLEHLO_ENABLE_LLD=ON \ + -DSTABLEHLO_ENABLE_LLD=OFF \ -DSTABLEHLO_ENABLE_BINDINGS_PYTHON=OFF \ -DSTABLEHLO_ENABLE_SPLIT_DWARF=ON \ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ From 9779290b896d44a36593d0270ca33bbca18d4402 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 1 Aug 2025 14:42:36 -0400 Subject: [PATCH 46/63] small fix for mac wheel cache name --- .github/workflows/build-wheel-macos-arm64.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build-wheel-macos-arm64.yaml b/.github/workflows/build-wheel-macos-arm64.yaml index f147684aa1..f755d2aa97 100644 --- a/.github/workflows/build-wheel-macos-arm64.yaml +++ b/.github/workflows/build-wheel-macos-arm64.yaml @@ -154,7 +154,7 @@ jobs: uses: actions/cache/restore@v4 with: path: ${{ github.workspace }}/stablehlo-build - key: ${{ matrix.container_img }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build + key: ${{ runner.os }}-${{ runner.arch }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build lookup-only: True - name: Check Enzyme Build Cache @@ -215,7 +215,7 @@ jobs: uses: actions/cache/save@v4 with: path: ${{ github.workspace }}/stablehlo-build - key: ${{ matrix.container_img }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build + key: ${{ runner.os }}-${{ runner.arch }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build - name: Build Enzyme if: steps.cache-enzyme-build.outputs.cache-hit != 'true' @@ -297,7 +297,7 @@ jobs: uses: actions/cache/restore@v4 with: path: ${{ github.workspace }}/stablehlo-build - key: ${{ matrix.container_img }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build + key: ${{ runner.os }}-${{ runner.arch }}-stablehlo-${{ needs.constants.outputs.stablehlo_version }}-wheel-build fail-on-cache-miss: True - name: Get Cached Enzyme Source From 4a7b902dc5f1301e207f55bcc0a08f5042926169 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 1 Aug 2025 14:44:20 -0400 Subject: [PATCH 47/63] go back to cpp 20: no longer embedded so no longer constrained by stablehlo still being on 17 --- mlir/CMakeLists.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index 4db6fccc9d..66d9790375 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -4,8 +4,7 @@ project(Catalyst LANGUAGES CXX C) set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -# stablehlo is still on cpp17 -set(CMAKE_CXX_STANDARD 17 CACHE STRING "C++ standard to conform to") +set(CMAKE_CXX_STANDARD 20 CACHE STRING "C++ standard to conform to") set(CMAKE_CXX_STANDARD_REQUIRED ON) # Required so as not to always use the cached option from the mlir build. From 5dc59ab60fa4cc21beb0c7fcbb3f103e8aa2a184 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 1 Aug 2025 15:09:27 -0400 Subject: [PATCH 48/63] move scatter-lowering and hlo-custom-call-lowering from catalyst dialect to stablehlo folder --- mlir/include/Catalyst/Transforms/Passes.h | 2 -- mlir/include/Catalyst/Transforms/Passes.td | 26 ----------------- mlir/include/stablehlo/Passes.h | 2 ++ mlir/include/stablehlo/Passes.td | 29 +++++++++++++++++++ mlir/include/stablehlo/Patterns.h | 27 +++++++++++++++++ mlir/lib/Catalyst/Transforms/CMakeLists.txt | 5 ---- mlir/lib/stablehlo/CMakeLists.txt | 6 +++- .../HloCustomCallPatterns.cpp | 3 +- .../ScatterPatterns.cpp | 1 - .../hlo_custom_call_lowering.cpp | 5 ++-- .../scatter_lowering.cpp | 5 ++-- 11 files changed, 70 insertions(+), 41 deletions(-) create mode 100644 mlir/include/stablehlo/Patterns.h rename mlir/lib/{Catalyst/Transforms => stablehlo}/HloCustomCallPatterns.cpp (99%) rename mlir/lib/{Catalyst/Transforms => stablehlo}/ScatterPatterns.cpp (99%) rename mlir/lib/{Catalyst/Transforms => stablehlo}/hlo_custom_call_lowering.cpp (95%) rename mlir/lib/{Catalyst/Transforms => stablehlo}/scatter_lowering.cpp (94%) diff --git a/mlir/include/Catalyst/Transforms/Passes.h b/mlir/include/Catalyst/Transforms/Passes.h index f7872961ad..4da01be727 100644 --- a/mlir/include/Catalyst/Transforms/Passes.h +++ b/mlir/include/Catalyst/Transforms/Passes.h @@ -29,13 +29,11 @@ std::unique_ptr createCatalystConversionPass(); std::unique_ptr createDetensorizeSCFPass(); std::unique_ptr createDisableAssertionPass(); std::unique_ptr createGEPInboundsPass(); -std::unique_ptr createHloCustomCallLoweringPass(); std::unique_ptr createInlineNestedModulePass(); std::unique_ptr createMemrefCopyToLinalgCopyPass(); std::unique_ptr createMemrefToLLVMWithTBAAPass(); std::unique_ptr createQnodeToAsyncLoweringPass(); std::unique_ptr createRegisterInactiveCallbackPass(); -std::unique_ptr createScatterLoweringPass(); std::unique_ptr createSplitMultipleTapesPass(); void registerAllCatalystPasses(); diff --git a/mlir/include/Catalyst/Transforms/Passes.td b/mlir/include/Catalyst/Transforms/Passes.td index e731cecddc..e86130286a 100644 --- a/mlir/include/Catalyst/Transforms/Passes.td +++ b/mlir/include/Catalyst/Transforms/Passes.td @@ -55,32 +55,6 @@ def CatalystConversionPass : Pass<"convert-catalyst-to-llvm"> { let constructor = "catalyst::createCatalystConversionPass()"; } -def ScatterLoweringPass : Pass<"scatter-lowering"> { - let summary = "Lower scatter op from Stable HLO to loops."; - - let dependentDialects = [ - "mlir::func::FuncDialect", - "index::IndexDialect", - "stablehlo::StablehloDialect", - "tensor::TensorDialect", - "scf::SCFDialect" - ]; - - let constructor = "catalyst::createScatterLoweringPass()"; -} - -def HloCustomCallLoweringPass : Pass<"hlo-custom-call-lowering"> { - let summary = "Lower custom calls op from Stable HLO to CallOp."; - - let dependentDialects = [ - "index::IndexDialect", - "mlir::func::FuncDialect", - "catalyst::CatalystDialect", - ]; - - let constructor = "catalyst::createHloCustomCallLoweringPass()"; -} - def QnodeToAsyncLoweringPass : Pass<"qnode-to-async-lowering"> { let summary = "Lower Qnode func and call operations to async func and call operations."; diff --git a/mlir/include/stablehlo/Passes.h b/mlir/include/stablehlo/Passes.h index d06976f017..c15ed04114 100644 --- a/mlir/include/stablehlo/Passes.h +++ b/mlir/include/stablehlo/Passes.h @@ -19,6 +19,8 @@ #include "mlir/Pass/Pass.h" namespace catalyst { +std::unique_ptr createHloCustomCallLoweringPass(); +std::unique_ptr createScatterLoweringPass(); std::unique_ptr createStablehloLegalizeSortPass(); std::unique_ptr createStablehloLegalizeToStdPass(); std::unique_ptr createStablehloLegalizeControlFlowPass(); diff --git a/mlir/include/stablehlo/Passes.td b/mlir/include/stablehlo/Passes.td index 64fd653ed1..bfa517c406 100644 --- a/mlir/include/stablehlo/Passes.td +++ b/mlir/include/stablehlo/Passes.td @@ -36,6 +36,35 @@ include "mlir/Pass/PassBase.td" +// -------------------- Catalyst's own hlo-related passes ------------------------ // + +def ScatterLoweringPass : Pass<"scatter-lowering"> { + let summary = "Lower scatter op from Stable HLO to loops."; + + let dependentDialects = [ + "mlir::func::FuncDialect", + "index::IndexDialect", + "stablehlo::StablehloDialect", + "tensor::TensorDialect", + "scf::SCFDialect" + ]; + + let constructor = "catalyst::createScatterLoweringPass()"; +} + +def HloCustomCallLoweringPass : Pass<"hlo-custom-call-lowering"> { + let summary = "Lower custom calls op from Stable HLO to CallOp."; + + let dependentDialects = [ + "index::IndexDialect", + "mlir::func::FuncDialect", + "catalyst::CatalystDialect", + ]; + + let constructor = "catalyst::createHloCustomCallLoweringPass()"; +} + +// -------------------- upstream mhlo passes removed in stablehlo ------------------------ // // stablehlo legalize sort pass. def StablehloLegalizeSortPass : Pass<"stablehlo-legalize-sort", "func::FuncOp"> { let summary = "Legalize from Stablehlo sort to SCF control flow."; diff --git a/mlir/include/stablehlo/Patterns.h b/mlir/include/stablehlo/Patterns.h new file mode 100644 index 0000000000..8caa0ecc1c --- /dev/null +++ b/mlir/include/stablehlo/Patterns.h @@ -0,0 +1,27 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace catalyst { + +void populateScatterPatterns(mlir::RewritePatternSet &); + +void populateHloCustomCallPatterns(mlir::RewritePatternSet &); + +} // namespace catalyst diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index a5342e3f0a..716be8d515 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -13,8 +13,6 @@ file(GLOB SRC DisableAssertionPatterns.cpp GEPInboundsPass.cpp GEPInboundsPatterns.cpp - hlo_custom_call_lowering.cpp - HloCustomCallPatterns.cpp InlineNestedModules.cpp MemrefCopyToLinalgCopyPass.cpp MemrefCopyToLinalgCopyPatterns.cpp @@ -22,8 +20,6 @@ file(GLOB SRC QnodeToAsyncPatterns.cpp RegisterAllPasses.cpp RegisterInactiveCallbackPass.cpp - scatter_lowering.cpp - ScatterPatterns.cpp SplitMultipleTapes.cpp TBAAPatterns.cpp TBAATagsPass.cpp @@ -45,4 +41,3 @@ target_include_directories(${LIBRARY_NAME} PUBLIC . ${PROJECT_SOURCE_DIR}/include ${CMAKE_BINARY_DIR}/include) -target_link_libraries(${LIBRARY_NAME} PRIVATE ExternalStablehloLib) diff --git a/mlir/lib/stablehlo/CMakeLists.txt b/mlir/lib/stablehlo/CMakeLists.txt index 4e035f0114..54c079d311 100644 --- a/mlir/lib/stablehlo/CMakeLists.txt +++ b/mlir/lib/stablehlo/CMakeLists.txt @@ -1,6 +1,10 @@ set(LIBRARY_NAME catalyst-stablehlo-transforms) file(GLOB SRC + hlo_custom_call_lowering.cpp + HloCustomCallPatterns.cpp + scatter_lowering.cpp + ScatterPatterns.cpp stablehlo_legalize_control_flow.cpp stablehlo_legalize_sort.cpp stablehlo_legalize_to_std.cpp @@ -14,8 +18,8 @@ set(LIBS ) set(DEPENDS - MLIRCatalystPassIncGen STABLEHLOCatalystPassIncGen + MLIRCatalystPassIncGen MLIRStablehloLegalizeToStandardIncGen ) diff --git a/mlir/lib/Catalyst/Transforms/HloCustomCallPatterns.cpp b/mlir/lib/stablehlo/HloCustomCallPatterns.cpp similarity index 99% rename from mlir/lib/Catalyst/Transforms/HloCustomCallPatterns.cpp rename to mlir/lib/stablehlo/HloCustomCallPatterns.cpp index 3b1fe91149..2168cc46e9 100644 --- a/mlir/lib/Catalyst/Transforms/HloCustomCallPatterns.cpp +++ b/mlir/lib/stablehlo/HloCustomCallPatterns.cpp @@ -17,9 +17,8 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" -#include "llvm/Support/Debug.h" - #include "stablehlo/dialect/StablehloOps.h" +#include "llvm/Support/Debug.h" #include "Catalyst/IR/CatalystOps.h" diff --git a/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp b/mlir/lib/stablehlo/ScatterPatterns.cpp similarity index 99% rename from mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp rename to mlir/lib/stablehlo/ScatterPatterns.cpp index 56312a719b..0887be843b 100644 --- a/mlir/lib/Catalyst/Transforms/ScatterPatterns.cpp +++ b/mlir/lib/stablehlo/ScatterPatterns.cpp @@ -22,7 +22,6 @@ #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" - #include "stablehlo/dialect/StablehloOps.h" using namespace mlir; diff --git a/mlir/lib/Catalyst/Transforms/hlo_custom_call_lowering.cpp b/mlir/lib/stablehlo/hlo_custom_call_lowering.cpp similarity index 95% rename from mlir/lib/Catalyst/Transforms/hlo_custom_call_lowering.cpp rename to mlir/lib/stablehlo/hlo_custom_call_lowering.cpp index b2bea6cb65..b4c3b2a79a 100644 --- a/mlir/lib/Catalyst/Transforms/hlo_custom_call_lowering.cpp +++ b/mlir/lib/stablehlo/hlo_custom_call_lowering.cpp @@ -27,7 +27,8 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "Catalyst/IR/CatalystDialect.h" -#include "Catalyst/Transforms/Patterns.h" +#include "stablehlo/Passes.h" +#include "stablehlo/Patterns.h" using namespace llvm; using namespace mlir; @@ -35,7 +36,7 @@ using namespace catalyst; namespace catalyst { #define GEN_PASS_DEF_HLOCUSTOMCALLLOWERINGPASS -#include "Catalyst/Transforms/Passes.h.inc" +#include "stablehlo/Passes.h.inc" struct HloCustomCallLoweringPass : impl::HloCustomCallLoweringPassBase { using HloCustomCallLoweringPassBase::HloCustomCallLoweringPassBase; diff --git a/mlir/lib/Catalyst/Transforms/scatter_lowering.cpp b/mlir/lib/stablehlo/scatter_lowering.cpp similarity index 94% rename from mlir/lib/Catalyst/Transforms/scatter_lowering.cpp rename to mlir/lib/stablehlo/scatter_lowering.cpp index 8e5cca4989..fd4f6f7487 100644 --- a/mlir/lib/Catalyst/Transforms/scatter_lowering.cpp +++ b/mlir/lib/stablehlo/scatter_lowering.cpp @@ -28,7 +28,8 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "Catalyst/Transforms/Patterns.h" +#include "stablehlo/Passes.h" +#include "stablehlo/Patterns.h" using namespace llvm; using namespace mlir; @@ -36,7 +37,7 @@ using namespace catalyst; namespace catalyst { #define GEN_PASS_DEF_SCATTERLOWERINGPASS -#include "Catalyst/Transforms/Passes.h.inc" +#include "stablehlo/Passes.h.inc" struct ScatterLoweringPass : impl::ScatterLoweringPassBase { using ScatterLoweringPassBase::ScatterLoweringPassBase; From c910f32983c213c4a254960a6d271211bd808b73 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 1 Aug 2025 15:12:40 -0400 Subject: [PATCH 49/63] small follow up --- mlir/include/Catalyst/Transforms/Patterns.h | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mlir/include/Catalyst/Transforms/Patterns.h b/mlir/include/Catalyst/Transforms/Patterns.h index 6bbf3150ff..53ebc04245 100644 --- a/mlir/include/Catalyst/Transforms/Patterns.h +++ b/mlir/include/Catalyst/Transforms/Patterns.h @@ -21,10 +21,6 @@ namespace catalyst { -void populateScatterPatterns(mlir::RewritePatternSet &); - -void populateHloCustomCallPatterns(mlir::RewritePatternSet &); - void populateQnodeToAsyncPatterns(mlir::RewritePatternSet &); void populateDisableAssertionPatterns(mlir::RewritePatternSet &); From 2f96459090a99958a0a34551d2987199e93b90ce Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 1 Aug 2025 15:24:00 -0400 Subject: [PATCH 50/63] add ${{ github.workspace }} to paths in wheels script --- .github/workflows/build-wheel-linux-arm64.yaml | 4 ++-- .github/workflows/build-wheel-linux-x86_64.yaml | 4 ++-- .github/workflows/build-wheel-macos-arm64.yaml | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/build-wheel-linux-arm64.yaml b/.github/workflows/build-wheel-linux-arm64.yaml index 214c4031b5..097fa4dd56 100644 --- a/.github/workflows/build-wheel-linux-arm64.yaml +++ b/.github/workflows/build-wheel-linux-arm64.yaml @@ -120,7 +120,7 @@ jobs: with: repository: openxla/stablehlo ref: ${{ needs.constants.outputs.stablehlo_version }} - path: mlir/stablehlo + path: ${{ github.workspace }}/mlir/stablehlo - name: Clone Enzyme Submodule if: steps.cache-enzyme-source.outputs.cache-hit != 'true' @@ -204,7 +204,7 @@ jobs: if: steps.cache-stablehlo-build.outputs.cache-hit != 'true' run: | export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH - cmake -S mlir/stablehlo -B $GITHUB_WORKSPACE/stablehlo-build -G Ninja \ + cmake -S $GITHUB_WORKSPACE/mlir/stablehlo -B $GITHUB_WORKSPACE/stablehlo-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ diff --git a/.github/workflows/build-wheel-linux-x86_64.yaml b/.github/workflows/build-wheel-linux-x86_64.yaml index 2aee0a081b..352066d6e5 100644 --- a/.github/workflows/build-wheel-linux-x86_64.yaml +++ b/.github/workflows/build-wheel-linux-x86_64.yaml @@ -139,7 +139,7 @@ jobs: with: repository: openxla/stablehlo ref: ${{ needs.constants.outputs.stablehlo_version }} - path: mlir/stablehlo + path: ${{ github.workspace }}/mlir/stablehlo - name: Clone Enzyme Submodule if: steps.cache-enzyme-source.outputs.cache-hit != 'true' @@ -226,7 +226,7 @@ jobs: if: steps.cache-stablehlo-build.outputs.cache-hit != 'true' run: | export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH - cmake -S mlir/stablehlo -B $GITHUB_WORKSPACE/stablehlo-build -G Ninja \ + cmake -S $GITHUB_WORKSPACE/mlir/stablehlo -B $GITHUB_WORKSPACE/stablehlo-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ diff --git a/.github/workflows/build-wheel-macos-arm64.yaml b/.github/workflows/build-wheel-macos-arm64.yaml index f755d2aa97..07743d018b 100644 --- a/.github/workflows/build-wheel-macos-arm64.yaml +++ b/.github/workflows/build-wheel-macos-arm64.yaml @@ -125,7 +125,7 @@ jobs: with: repository: openxla/stablehlo ref: ${{ needs.constants.outputs.stablehlo_version }} - path: mlir/stablehlo + path: ${{ github.workspace }}/mlir/stablehlo - name: Clone Enzyme Submodule if: steps.cache-enzyme-source.outputs.cache-hit != 'true' @@ -197,7 +197,7 @@ jobs: if: steps.cache-stablehlo-build.outputs.cache-hit != 'true' run: | export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH - cmake -S mlir/stablehlo -B $GITHUB_WORKSPACE/stablehlo-build -G Ninja \ + cmake -S $GITHUB_WORKSPACE/mlir/stablehlo -B $GITHUB_WORKSPACE/stablehlo-build -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ From 8cc592ad0509691011757758ecbb89fe79d991c1 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 14 Aug 2025 12:05:06 -0400 Subject: [PATCH 51/63] rename custom `stablehlo` folder to `hlo-extensions` --- mlir/include/CMakeLists.txt | 2 +- .../{stablehlo => hlo-extensions}/CMakeLists.txt | 0 mlir/include/{stablehlo => hlo-extensions}/Passes.h | 0 mlir/include/{stablehlo => hlo-extensions}/Passes.td | 0 .../include/{stablehlo => hlo-extensions}/Patterns.h | 0 .../stablehlo_legalize_to_standard_patterns.td | 0 mlir/lib/CMakeLists.txt | 2 +- mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp | 2 +- mlir/lib/Driver/Pipelines.cpp | 2 +- .../lib/{stablehlo => hlo-extensions}/CMakeLists.txt | 0 .../HloCustomCallPatterns.cpp | 2 +- .../ScatterPatterns.cpp | 0 .../hlo_custom_call_lowering.cpp | 12 +++++------- .../scatter_lowering.cpp | 12 +++++------- .../stablehlo_legalize_control_flow.cpp | 6 +++--- .../stablehlo_legalize_sort.cpp | 6 +++--- .../stablehlo_legalize_to_std.cpp | 7 +++---- mlir/tools/quantum-opt/quantum-opt.cpp | 6 ++---- 18 files changed, 26 insertions(+), 33 deletions(-) rename mlir/include/{stablehlo => hlo-extensions}/CMakeLists.txt (100%) rename mlir/include/{stablehlo => hlo-extensions}/Passes.h (100%) rename mlir/include/{stablehlo => hlo-extensions}/Passes.td (100%) rename mlir/include/{stablehlo => hlo-extensions}/Patterns.h (100%) rename mlir/include/{stablehlo => hlo-extensions}/stablehlo_legalize_to_standard_patterns.td (100%) rename mlir/lib/{stablehlo => hlo-extensions}/CMakeLists.txt (100%) rename mlir/lib/{stablehlo => hlo-extensions}/HloCustomCallPatterns.cpp (100%) rename mlir/lib/{stablehlo => hlo-extensions}/ScatterPatterns.cpp (100%) rename mlir/lib/{stablehlo => hlo-extensions}/hlo_custom_call_lowering.cpp (94%) rename mlir/lib/{stablehlo => hlo-extensions}/scatter_lowering.cpp (94%) rename mlir/lib/{stablehlo => hlo-extensions}/stablehlo_legalize_control_flow.cpp (99%) rename mlir/lib/{stablehlo => hlo-extensions}/stablehlo_legalize_sort.cpp (99%) rename mlir/lib/{stablehlo => hlo-extensions}/stablehlo_legalize_to_std.cpp (98%) diff --git a/mlir/include/CMakeLists.txt b/mlir/include/CMakeLists.txt index 4dc4680506..85180214a9 100644 --- a/mlir/include/CMakeLists.txt +++ b/mlir/include/CMakeLists.txt @@ -1,9 +1,9 @@ add_subdirectory(Catalyst) add_subdirectory(Gradient) +add_subdirectory(hlo-extensions) add_subdirectory(Ion) add_subdirectory(MBQC) add_subdirectory(Mitigation) add_subdirectory(QEC) add_subdirectory(Quantum) -add_subdirectory(stablehlo) add_subdirectory(Test) diff --git a/mlir/include/stablehlo/CMakeLists.txt b/mlir/include/hlo-extensions/CMakeLists.txt similarity index 100% rename from mlir/include/stablehlo/CMakeLists.txt rename to mlir/include/hlo-extensions/CMakeLists.txt diff --git a/mlir/include/stablehlo/Passes.h b/mlir/include/hlo-extensions/Passes.h similarity index 100% rename from mlir/include/stablehlo/Passes.h rename to mlir/include/hlo-extensions/Passes.h diff --git a/mlir/include/stablehlo/Passes.td b/mlir/include/hlo-extensions/Passes.td similarity index 100% rename from mlir/include/stablehlo/Passes.td rename to mlir/include/hlo-extensions/Passes.td diff --git a/mlir/include/stablehlo/Patterns.h b/mlir/include/hlo-extensions/Patterns.h similarity index 100% rename from mlir/include/stablehlo/Patterns.h rename to mlir/include/hlo-extensions/Patterns.h diff --git a/mlir/include/stablehlo/stablehlo_legalize_to_standard_patterns.td b/mlir/include/hlo-extensions/stablehlo_legalize_to_standard_patterns.td similarity index 100% rename from mlir/include/stablehlo/stablehlo_legalize_to_standard_patterns.td rename to mlir/include/hlo-extensions/stablehlo_legalize_to_standard_patterns.td diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt index f669d222c8..687046ef64 100644 --- a/mlir/lib/CMakeLists.txt +++ b/mlir/lib/CMakeLists.txt @@ -2,10 +2,10 @@ add_subdirectory(CAPI) add_subdirectory(Catalyst) add_subdirectory(Driver) add_subdirectory(Gradient) +add_subdirectory(hlo-extensions) add_subdirectory(Ion) add_subdirectory(MBQC) add_subdirectory(Mitigation) add_subdirectory(QEC) add_subdirectory(Quantum) -add_subdirectory(stablehlo) add_subdirectory(Test) diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index 9e3e1be7b8..1ef631247b 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -22,7 +22,7 @@ #include "QEC/Transforms/Passes.h" #include "Quantum/Transforms/Passes.h" #include "Test/Transforms/Passes.h" -#include "stablehlo/Passes.h" +#include "hlo-extensions/Passes.h" void catalyst::registerAllCatalystPasses() { diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index e5062457c7..ccf4e4ab4f 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -31,7 +31,7 @@ #include "Mitigation/Transforms/Passes.h" #include "Quantum/IR/QuantumDialect.h" #include "Quantum/Transforms/Passes.h" -#include "stablehlo/Passes.h" +#include "hlo-extensions/Passes.h" using namespace mlir; namespace catalyst { diff --git a/mlir/lib/stablehlo/CMakeLists.txt b/mlir/lib/hlo-extensions/CMakeLists.txt similarity index 100% rename from mlir/lib/stablehlo/CMakeLists.txt rename to mlir/lib/hlo-extensions/CMakeLists.txt diff --git a/mlir/lib/stablehlo/HloCustomCallPatterns.cpp b/mlir/lib/hlo-extensions/HloCustomCallPatterns.cpp similarity index 100% rename from mlir/lib/stablehlo/HloCustomCallPatterns.cpp rename to mlir/lib/hlo-extensions/HloCustomCallPatterns.cpp index 2168cc46e9..9da9e413f3 100644 --- a/mlir/lib/stablehlo/HloCustomCallPatterns.cpp +++ b/mlir/lib/hlo-extensions/HloCustomCallPatterns.cpp @@ -14,11 +14,11 @@ #define DEBUG_TYPE "scatter" +#include "llvm/Support/Debug.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" #include "stablehlo/dialect/StablehloOps.h" -#include "llvm/Support/Debug.h" #include "Catalyst/IR/CatalystOps.h" diff --git a/mlir/lib/stablehlo/ScatterPatterns.cpp b/mlir/lib/hlo-extensions/ScatterPatterns.cpp similarity index 100% rename from mlir/lib/stablehlo/ScatterPatterns.cpp rename to mlir/lib/hlo-extensions/ScatterPatterns.cpp diff --git a/mlir/lib/stablehlo/hlo_custom_call_lowering.cpp b/mlir/lib/hlo-extensions/hlo_custom_call_lowering.cpp similarity index 94% rename from mlir/lib/stablehlo/hlo_custom_call_lowering.cpp rename to mlir/lib/hlo-extensions/hlo_custom_call_lowering.cpp index b4c3b2a79a..11b3ef4e4d 100644 --- a/mlir/lib/stablehlo/hlo_custom_call_lowering.cpp +++ b/mlir/lib/hlo-extensions/hlo_custom_call_lowering.cpp @@ -17,18 +17,16 @@ #include #include "llvm/Support/Debug.h" - -#include "stablehlo/dialect/StablehloOps.h" -#include "stablehlo/transforms/Passes.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" #include "Catalyst/IR/CatalystDialect.h" -#include "stablehlo/Passes.h" -#include "stablehlo/Patterns.h" +#include "hlo-extensions/Passes.h" +#include "hlo-extensions/Patterns.h" using namespace llvm; using namespace mlir; @@ -36,7 +34,7 @@ using namespace catalyst; namespace catalyst { #define GEN_PASS_DEF_HLOCUSTOMCALLLOWERINGPASS -#include "stablehlo/Passes.h.inc" +#include "hlo-extensions/Passes.h.inc" struct HloCustomCallLoweringPass : impl::HloCustomCallLoweringPassBase { using HloCustomCallLoweringPassBase::HloCustomCallLoweringPassBase; diff --git a/mlir/lib/stablehlo/scatter_lowering.cpp b/mlir/lib/hlo-extensions/scatter_lowering.cpp similarity index 94% rename from mlir/lib/stablehlo/scatter_lowering.cpp rename to mlir/lib/hlo-extensions/scatter_lowering.cpp index fd4f6f7487..1b670132c3 100644 --- a/mlir/lib/stablehlo/scatter_lowering.cpp +++ b/mlir/lib/hlo-extensions/scatter_lowering.cpp @@ -17,19 +17,17 @@ #include #include "llvm/Support/Debug.h" - -#include "stablehlo/dialect/StablehloOps.h" -#include "stablehlo/transforms/Passes.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/Passes.h" -#include "stablehlo/Passes.h" -#include "stablehlo/Patterns.h" +#include "hlo-extensions/Passes.h" +#include "hlo-extensions/Patterns.h" using namespace llvm; using namespace mlir; @@ -37,7 +35,7 @@ using namespace catalyst; namespace catalyst { #define GEN_PASS_DEF_SCATTERLOWERINGPASS -#include "stablehlo/Passes.h.inc" +#include "hlo-extensions/Passes.h.inc" struct ScatterLoweringPass : impl::ScatterLoweringPassBase { using ScatterLoweringPassBase::ScatterLoweringPassBase; diff --git a/mlir/lib/stablehlo/stablehlo_legalize_control_flow.cpp b/mlir/lib/hlo-extensions/stablehlo_legalize_control_flow.cpp similarity index 99% rename from mlir/lib/stablehlo/stablehlo_legalize_control_flow.cpp rename to mlir/lib/hlo-extensions/stablehlo_legalize_control_flow.cpp index 0538e9765c..86020cfb60 100644 --- a/mlir/lib/stablehlo/stablehlo_legalize_control_flow.cpp +++ b/mlir/lib/hlo-extensions/stablehlo_legalize_control_flow.cpp @@ -41,6 +41,7 @@ limitations under the License. #include #include +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project @@ -56,9 +57,8 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/transforms/Passes.h" -#include "llvm/Support/Casting.h" -#include "stablehlo/Passes.h" +#include "hlo-extensions/Passes.h" using namespace mlir; using namespace stablehlo; @@ -68,7 +68,7 @@ namespace catalyst { #define GEN_PASS_DEF_STABLEHLOLEGALIZECONTROLFLOWPASS #define GEN_PASS_DECL_STABLEHLOLEGALIZECONTROLFLOWPASS -#include "stablehlo/Passes.h.inc" +#include "hlo-extensions/Passes.h.inc" } // namespace catalyst diff --git a/mlir/lib/stablehlo/stablehlo_legalize_sort.cpp b/mlir/lib/hlo-extensions/stablehlo_legalize_sort.cpp similarity index 99% rename from mlir/lib/stablehlo/stablehlo_legalize_sort.cpp rename to mlir/lib/hlo-extensions/stablehlo_legalize_sort.cpp index dc2b296ab6..ad26ab6ea7 100644 --- a/mlir/lib/stablehlo/stablehlo_legalize_sort.cpp +++ b/mlir/lib/hlo-extensions/stablehlo_legalize_sort.cpp @@ -41,6 +41,7 @@ limitations under the License. #include #include +#include "llvm/ADT/STLExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -63,9 +64,8 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/transforms/Passes.h" -#include "llvm/ADT/STLExtras.h" -#include "stablehlo/Passes.h" +#include "hlo-extensions/Passes.h" using namespace mlir; using namespace stablehlo; @@ -75,7 +75,7 @@ namespace catalyst { #define GEN_PASS_DEF_STABLEHLOLEGALIZESORTPASS #define GEN_PASS_DECL_STABLEHLOLEGALIZESORTPASS -#include "stablehlo/Passes.h.inc" +#include "hlo-extensions/Passes.h.inc" } // namespace catalyst diff --git a/mlir/lib/stablehlo/stablehlo_legalize_to_std.cpp b/mlir/lib/hlo-extensions/stablehlo_legalize_to_std.cpp similarity index 98% rename from mlir/lib/stablehlo/stablehlo_legalize_to_std.cpp rename to mlir/lib/hlo-extensions/stablehlo_legalize_to_std.cpp index bfab0f4673..1cec1aa55a 100644 --- a/mlir/lib/stablehlo/stablehlo_legalize_to_std.cpp +++ b/mlir/lib/hlo-extensions/stablehlo_legalize_to_std.cpp @@ -42,7 +42,6 @@ limitations under the License. #include #include -// #include "mhlo/transforms/rewriters.h" // (??) #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -53,7 +52,7 @@ limitations under the License. #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/transforms/Passes.h" -#include "stablehlo/Passes.h" +#include "hlo-extensions/Passes.h" using namespace mlir; using namespace stablehlo; @@ -63,8 +62,8 @@ namespace catalyst { #define GEN_PASS_DEF_STABLEHLOLEGALIZETOSTANDARDPASS #define GEN_PASS_DECL_STABLEHLOLEGALIZETOSTANDARDPASS -#include "stablehlo/Passes.h.inc" -#include "stablehlo/generated_stablehlo_legalize_to_standard.cpp.inc" +#include "hlo-extensions/Passes.h.inc" +#include "hlo-extensions/generated_stablehlo_legalize_to_standard.cpp.inc" } // namespace catalyst diff --git a/mlir/tools/quantum-opt/quantum-opt.cpp b/mlir/tools/quantum-opt/quantum-opt.cpp index a1a3caebd4..01fe5abd02 100644 --- a/mlir/tools/quantum-opt/quantum-opt.cpp +++ b/mlir/tools/quantum-opt/quantum-opt.cpp @@ -16,16 +16,14 @@ #include // ifstream #include //regex +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" -#include "stablehlo/dialect/Register.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/raw_ostream.h" - #include "stablehlo/dialect/Register.h" #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/integrations/c/StablehloPasses.h" From c915e593df7fc9f5916df5dfb1231c692b4c0d81 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 14 Aug 2025 12:09:16 -0400 Subject: [PATCH 52/63] format --- mlir/lib/hlo-extensions/HloCustomCallPatterns.cpp | 2 +- mlir/lib/hlo-extensions/hlo_custom_call_lowering.cpp | 2 +- mlir/lib/hlo-extensions/scatter_lowering.cpp | 2 +- mlir/lib/hlo-extensions/stablehlo_legalize_control_flow.cpp | 2 +- mlir/lib/hlo-extensions/stablehlo_legalize_sort.cpp | 2 +- mlir/tools/quantum-opt/quantum-opt.cpp | 4 ++-- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mlir/lib/hlo-extensions/HloCustomCallPatterns.cpp b/mlir/lib/hlo-extensions/HloCustomCallPatterns.cpp index 9da9e413f3..2168cc46e9 100644 --- a/mlir/lib/hlo-extensions/HloCustomCallPatterns.cpp +++ b/mlir/lib/hlo-extensions/HloCustomCallPatterns.cpp @@ -14,11 +14,11 @@ #define DEBUG_TYPE "scatter" -#include "llvm/Support/Debug.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" #include "stablehlo/dialect/StablehloOps.h" +#include "llvm/Support/Debug.h" #include "Catalyst/IR/CatalystOps.h" diff --git a/mlir/lib/hlo-extensions/hlo_custom_call_lowering.cpp b/mlir/lib/hlo-extensions/hlo_custom_call_lowering.cpp index 11b3ef4e4d..cb0a97dc8e 100644 --- a/mlir/lib/hlo-extensions/hlo_custom_call_lowering.cpp +++ b/mlir/lib/hlo-extensions/hlo_custom_call_lowering.cpp @@ -16,13 +16,13 @@ #include -#include "llvm/Support/Debug.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/transforms/Passes.h" +#include "llvm/Support/Debug.h" #include "Catalyst/IR/CatalystDialect.h" #include "hlo-extensions/Passes.h" diff --git a/mlir/lib/hlo-extensions/scatter_lowering.cpp b/mlir/lib/hlo-extensions/scatter_lowering.cpp index 1b670132c3..1726ce8a62 100644 --- a/mlir/lib/hlo-extensions/scatter_lowering.cpp +++ b/mlir/lib/hlo-extensions/scatter_lowering.cpp @@ -16,7 +16,6 @@ #include -#include "llvm/Support/Debug.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -25,6 +24,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/transforms/Passes.h" +#include "llvm/Support/Debug.h" #include "hlo-extensions/Passes.h" #include "hlo-extensions/Patterns.h" diff --git a/mlir/lib/hlo-extensions/stablehlo_legalize_control_flow.cpp b/mlir/lib/hlo-extensions/stablehlo_legalize_control_flow.cpp index 86020cfb60..03ad07f78d 100644 --- a/mlir/lib/hlo-extensions/stablehlo_legalize_control_flow.cpp +++ b/mlir/lib/hlo-extensions/stablehlo_legalize_control_flow.cpp @@ -41,7 +41,6 @@ limitations under the License. #include #include -#include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project @@ -57,6 +56,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/transforms/Passes.h" +#include "llvm/Support/Casting.h" #include "hlo-extensions/Passes.h" diff --git a/mlir/lib/hlo-extensions/stablehlo_legalize_sort.cpp b/mlir/lib/hlo-extensions/stablehlo_legalize_sort.cpp index ad26ab6ea7..29162dccbd 100644 --- a/mlir/lib/hlo-extensions/stablehlo_legalize_sort.cpp +++ b/mlir/lib/hlo-extensions/stablehlo_legalize_sort.cpp @@ -41,7 +41,6 @@ limitations under the License. #include #include -#include "llvm/ADT/STLExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -64,6 +63,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/transforms/Passes.h" +#include "llvm/ADT/STLExtras.h" #include "hlo-extensions/Passes.h" diff --git a/mlir/tools/quantum-opt/quantum-opt.cpp b/mlir/tools/quantum-opt/quantum-opt.cpp index 01fe5abd02..1d252733ea 100644 --- a/mlir/tools/quantum-opt/quantum-opt.cpp +++ b/mlir/tools/quantum-opt/quantum-opt.cpp @@ -16,8 +16,6 @@ #include // ifstream #include //regex -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/DialectRegistry.h" @@ -29,6 +27,8 @@ #include "stablehlo/integrations/c/StablehloPasses.h" #include "stablehlo/transforms/Passes.h" #include "stablehlo/transforms/optimization/Passes.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/raw_ostream.h" #include "Catalyst/IR/CatalystDialect.h" #include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h" From 8a3a4846c4bda6f3f2ec83792a98980c8d1d647b Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 14 Aug 2025 12:11:38 -0400 Subject: [PATCH 53/63] remove set_dep_versions.py --- .github/workflows/set_dep_versions.py | 82 --------------------------- 1 file changed, 82 deletions(-) delete mode 100644 .github/workflows/set_dep_versions.py diff --git a/.github/workflows/set_dep_versions.py b/.github/workflows/set_dep_versions.py deleted file mode 100644 index d2340d4ea4..0000000000 --- a/.github/workflows/set_dep_versions.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2022-2023 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This module computes commit hashes for LLVM and MLIR-HLO based on a given JAX version. -""" - -# pylint: disable=line-too-long -# pylint: disable=anomalous-backslash-in-string -# pylint: disable=consider-using-with - -import os -import re -import sys - -import requests - -jax_version = sys.argv[1] -dep_versions_path = os.path.join(os.path.dirname(__file__), "../../.dep-versions") -catalyst_init_path = os.path.join(os.path.dirname(__file__), "../../frontend/catalyst/__init__.py") - -assert os.path.isfile(dep_versions_path) -assert os.path.isfile(catalyst_init_path) - -url = f"https://raw.githubusercontent.com/jax-ml/jax/jax-v{jax_version}/WORKSPACE" -response = requests.get(url) -match = re.search(r'strip_prefix = "xla-([a-zA-Z0-9]*)"', response.text) -if not match: - url = f"https://raw.githubusercontent.com/jax-ml/jax/jax-v{jax_version}/third_party/xla/workspace.bzl" - response = requests.get(url) - match = re.search(r'XLA_COMMIT = "([a-zA-Z0-9]*)"', response.text) -xla_commit = match.group(1) - -url = f"https://raw.githubusercontent.com/openxla/xla/{xla_commit}/third_party/llvm/workspace.bzl" -response = requests.get(url) -match = re.search(r'LLVM_COMMIT = "([a-zA-Z0-9]*)"', response.text) -llvm_commit = match.group(1) - -# If the XLA commit is an "Integrate LLVM" commit we need to get the piper_id directly from there -# to look up the corresponding mlir-hlo commit. -url = f"https://api.github.com/repos/openxla/xla/commits?sha={xla_commit}&per_page=1" -response = requests.get(url).json() -match = re.search(r"Integrate LLVM", response[0]["commit"]["message"]) -if match: - match = re.search(r"PiperOrigin-RevId: ([0-9]*)", response[0]["commit"]["message"]) - piper_id = match.group(1) -else: - # Otherwise, we get the last commit in the XLA repository that touched the mlir-hlo files, and - # get its piper_id to get the same commit in the standalone mlir-hlo repo. - url = f"https://api.github.com/repos/openxla/xla/commits?sha={xla_commit}&path=xla/mlir_hlo&per_page=1" - response = requests.get(url).json() - xla_hlo_commit = response[0]["sha"] - match = re.search(r"PiperOrigin-RevId: ([0-9]*)", response[0]["commit"]["message"]) - piper_id = match.group(1) - -url = f"https://api.github.com/search/commits?q=repo:tensorflow/mlir-hlo+{piper_id}" -response = requests.get(url).json() -hlo_commit = response["items"][0]["sha"] - -quote = '"' -# Update each version using sed -cmds = [ - f"sed -i '' 's/^jax=.*/jax={jax_version}/' {dep_versions_path}", - f"sed -i '' 's/^stablehlo=.*/stablehlo={hlo_commit}/' {dep_versions_path}", - f"sed -i '' 's/^llvm=.*/llvm={llvm_commit}/' {dep_versions_path}", - # Update jaxlib version in __init__.py - rf"sed -i '' 's/_jaxlib_version = {quote}\([0-9.]\+\){quote}/_jaxlib_version = {quote}{jax_version}{quote}/g' {catalyst_init_path}", -] - -for cmd in cmds: - res = os.system(cmd) - assert res == 0 From b3a24d0d8a4493254b4135cecc1ba17202867c35 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 14 Aug 2025 13:45:54 -0400 Subject: [PATCH 54/63] changelog --- doc/releases/changelog-dev.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index bd9ea567d8..a32d2b73fa 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -42,13 +42,15 @@ * The JAX version used by Catalyst is updated to 0.6.2. [(#1897)](https://github.com/PennyLaneAI/catalyst/pull/1897) -* The version of LLVM, mlir-hlo, and Enzyme used by Catalyst has been updated. +* The version of LLVM and Enzyme used by Catalyst has been updated. + The mlir-hlo dependency has been replaced with stablehlo. [(#1916)](https://github.com/PennyLaneAI/catalyst/pull/1916) + [(#1921)](https://github.com/PennyLaneAI/catalyst/pull/1921) The LLVM version has been updated to [commit f8cb798](https://github.com/llvm/llvm-project/tree/f8cb7987c64dcffb72414a40560055cb717dbf74). - The mlir-hlo version has been updated to - [commit 1dd2e71](https://github.com/tensorflow/mlir-hlo/tree/1dd2e71331014ae0373f6bf900ce6be393357190). + The stablehlo version has been updated to + [commit 69d6dae](https://github.com/openxla/stablehlo/commit/69d6dae46e1c7de36e6e6973654754f05353cba5). The Enzyme version has been updated to [v0.0.186](https://github.com/EnzymeAD/Enzyme/releases/tag/v0.0.186). From fdf96d2c2b06f45dbe466c0b332fa86317bab27f Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 14 Aug 2025 17:00:14 -0400 Subject: [PATCH 55/63] add SYSTEM to stablehlo includes --- mlir/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index 72c03b0191..894ddb8a12 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -80,7 +80,7 @@ foreach(STABLEHLO_LIB IN LISTS STABLEHLO_LIBS) target_link_libraries(ExternalStablehloLib INTERFACE ${STABLEHLO_LIB}) endforeach() -target_include_directories(ExternalStablehloLib INTERFACE +target_include_directories(ExternalStablehloLib SYSTEM INTERFACE ${STABLEHLO_DIR} ${STABLEHLO_BUILD_DIR} # for the generated .inc files ) From 42edcc6599bab7f5dc52fe6abafec7bcde209942 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 14 Aug 2025 17:12:33 -0400 Subject: [PATCH 56/63] put back enable_lld/zlib in stablehlo build --- mlir/Makefile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/Makefile b/mlir/Makefile index 4c94146664..e3d259af01 100644 --- a/mlir/Makefile +++ b/mlir/Makefile @@ -103,6 +103,8 @@ stablehlo: -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \ -DMLIR_DIR=$(LLVM_BUILD_DIR)/lib/cmake/mlir \ -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_ENABLE_LLD=$(ENABLE_LLD) \ + -DLLVM_ENABLE_ZLIB=$(ENABLE_ZLIB) \ -DSTABLEHLO_ENABLE_BINDINGS_PYTHON=OFF \ -DSTABLEHLO_ENABLE_SPLIT_DWARF=ON \ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ From 5b99ac5d88e9688bd460d53489e40af902828a45 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 14 Aug 2025 17:22:07 -0400 Subject: [PATCH 57/63] clean cmake --- mlir/lib/Driver/CMakeLists.txt | 3 +-- mlir/lib/hlo-extensions/CMakeLists.txt | 2 +- mlir/tools/catalyst-cli/CMakeLists.txt | 2 +- mlir/tools/quantum-lsp-server/CMakeLists.txt | 2 +- mlir/tools/quantum-opt/CMakeLists.txt | 2 +- 5 files changed, 5 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Driver/CMakeLists.txt b/mlir/lib/Driver/CMakeLists.txt index 43ec38a7c1..c3a82be9ec 100644 --- a/mlir/lib/Driver/CMakeLists.txt +++ b/mlir/lib/Driver/CMakeLists.txt @@ -25,6 +25,7 @@ set(LIBS ${conversion_libs} ${extension_libs} ${translation_libs} + ExternalStablehloLib MLIROptLib MLIRCatalyst catalyst-transforms @@ -52,5 +53,3 @@ add_mlir_library(CatalystCompilerDriver LINK_LIBS PRIVATE ${LIBS} ) - -target_link_libraries(CatalystCompilerDriver PRIVATE ExternalStablehloLib) diff --git a/mlir/lib/hlo-extensions/CMakeLists.txt b/mlir/lib/hlo-extensions/CMakeLists.txt index 54c079d311..de0833b29d 100644 --- a/mlir/lib/hlo-extensions/CMakeLists.txt +++ b/mlir/lib/hlo-extensions/CMakeLists.txt @@ -15,6 +15,7 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) set(LIBS ${dialect_libs} ${conversion_libs} + ExternalStablehloLib ) set(DEPENDS @@ -29,4 +30,3 @@ target_include_directories(${LIBRARY_NAME} PUBLIC . ${PROJECT_SOURCE_DIR}/include ${CMAKE_BINARY_DIR}/include) -target_link_libraries(${LIBRARY_NAME} PRIVATE ExternalStablehloLib) diff --git a/mlir/tools/catalyst-cli/CMakeLists.txt b/mlir/tools/catalyst-cli/CMakeLists.txt index 93b204347d..1dd3d9693d 100644 --- a/mlir/tools/catalyst-cli/CMakeLists.txt +++ b/mlir/tools/catalyst-cli/CMakeLists.txt @@ -23,6 +23,7 @@ set(LIBS ${dialect_libs} ${conversion_libs} ${extension_libs} + ExternalStablehloLib MLIROptLib MLIRCatalyst catalyst-transforms @@ -46,7 +47,6 @@ set(LIBS add_mlir_tool(catalyst-cli catalyst-cli.cpp SUPPORT_PLUGINS) target_link_libraries(catalyst-cli PRIVATE ${LIBS}) -target_link_libraries(catalyst-cli PRIVATE ExternalStablehloLib) llvm_update_compile_flags(catalyst-cli) mlir_check_all_link_libraries(catalyst-cli) export_executable_symbols_for_plugins(catalyst-cli) diff --git a/mlir/tools/quantum-lsp-server/CMakeLists.txt b/mlir/tools/quantum-lsp-server/CMakeLists.txt index 8809649d62..f4a7c2e727 100644 --- a/mlir/tools/quantum-lsp-server/CMakeLists.txt +++ b/mlir/tools/quantum-lsp-server/CMakeLists.txt @@ -3,6 +3,7 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) set(LIBS ${dialect_libs} ${conversion_libs} + ExternalStablehloLib MLIRLspServerLib MLIRCatalyst MLIRQuantum @@ -15,6 +16,5 @@ set(LIBS add_llvm_executable(quantum-lsp-server quantum-lsp-server.cpp) target_link_libraries(quantum-lsp-server PRIVATE ${LIBS}) -target_link_libraries(quantum-lsp-server PRIVATE ExternalStablehloLib) llvm_update_compile_flags(quantum-lsp-server) mlir_check_all_link_libraries(quantum-lsp-server) diff --git a/mlir/tools/quantum-opt/CMakeLists.txt b/mlir/tools/quantum-opt/CMakeLists.txt index 922d4b71da..10c6ed5a0f 100644 --- a/mlir/tools/quantum-opt/CMakeLists.txt +++ b/mlir/tools/quantum-opt/CMakeLists.txt @@ -5,6 +5,7 @@ set(LIBS ${dialect_libs} ${conversion_libs} ${extension_libs} + ExternalStablehloLib MLIROptLib MLIRCatalyst catalyst-transforms @@ -28,7 +29,6 @@ set(LIBS add_mlir_tool(quantum-opt quantum-opt.cpp DEPENDS ${LIBS} SUPPORT_PLUGINS) target_link_libraries(quantum-opt PRIVATE ${LIBS}) -target_link_libraries(quantum-opt PRIVATE ExternalStablehloLib) llvm_update_compile_flags(quantum-opt) mlir_check_all_link_libraries(quantum-opt) export_executable_symbols_for_plugins(quantum-opt) From bc1d58ccb0ad0825155763592a2f2c8166a779f6 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 14 Aug 2025 17:29:05 -0400 Subject: [PATCH 58/63] track enable_lld --- mlir/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/Makefile b/mlir/Makefile index e3d259af01..8fc76e11e7 100644 --- a/mlir/Makefile +++ b/mlir/Makefile @@ -99,7 +99,7 @@ stablehlo: @echo "build stablehlo" cmake -G Ninja -S stablehlo -B $(STABLEHLO_BUILD_DIR) \ - -DSTABLEHLO_ENABLE_LLD=ON \ + -DSTABLEHLO_ENABLE_LLD=$(ENABLE_LLD) \ -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \ -DMLIR_DIR=$(LLVM_BUILD_DIR)/lib/cmake/mlir \ -DLLVM_ENABLE_ASSERTIONS=ON \ From a035970a95fc2c05c8efb35a2150760a13edd911 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 14 Aug 2025 17:30:53 -0400 Subject: [PATCH 59/63] CI From 004e502ab324aff3629f7b771e02dfb60f6d289e Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 14 Aug 2025 17:37:03 -0400 Subject: [PATCH 60/63] try `make stablehlo` in wheels --- .github/workflows/build-wheel-linux-arm64.yaml | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/.github/workflows/build-wheel-linux-arm64.yaml b/.github/workflows/build-wheel-linux-arm64.yaml index 097fa4dd56..d05d8f4f7e 100644 --- a/.github/workflows/build-wheel-linux-arm64.yaml +++ b/.github/workflows/build-wheel-linux-arm64.yaml @@ -203,18 +203,12 @@ jobs: - name: Build Stablehlo Dialect if: steps.cache-stablehlo-build.outputs.cache-hit != 'true' run: | - export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH - cmake -S $GITHUB_WORKSPACE/mlir/stablehlo -B $GITHUB_WORKSPACE/stablehlo-build -G Ninja \ - -DCMAKE_BUILD_TYPE=Release \ - -DLLVM_ENABLE_ASSERTIONS=ON \ - -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ - -DSTABLEHLO_ENABLE_LLD=ON \ - -DSTABLEHLO_ENABLE_BINDINGS_PYTHON=OFF \ - -DSTABLEHLO_ENABLE_SPLIT_DWARF=ON \ - -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ - -DCMAKE_CXX_VISIBILITY_PRESET=default - - cmake --build $GITHUB_WORKSPACE/stablehlo-build + C_COMPILER=$(which ${{ needs.constants.outputs[format('c_compiler.{0}', matrix.compiler)] }}) \ + CXX_COMPILER=$(which ${{ needs.constants.outputs[format('cxx_compiler.{0}', matrix.compiler)] }}) \ + LLVM_BUILD_DIR="$(pwd)/llvm-build" \ + STABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ + COMPILER_LAUNCHER="" \ + make stablehlo - name: Save Stablehlo Build id: save-stablehlo-build From 65dbfca63b48eec8682eae920b86acdaa0f5319a Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 14 Aug 2025 17:50:04 -0400 Subject: [PATCH 61/63] wheels --- .github/workflows/build-wheel-linux-arm64.yaml | 4 ++-- .../workflows/build-wheel-linux-x86_64.yaml | 18 ++++++------------ .github/workflows/build-wheel-macos-arm64.yaml | 16 ++++------------ 3 files changed, 12 insertions(+), 26 deletions(-) diff --git a/.github/workflows/build-wheel-linux-arm64.yaml b/.github/workflows/build-wheel-linux-arm64.yaml index d05d8f4f7e..257eb59cf3 100644 --- a/.github/workflows/build-wheel-linux-arm64.yaml +++ b/.github/workflows/build-wheel-linux-arm64.yaml @@ -203,8 +203,8 @@ jobs: - name: Build Stablehlo Dialect if: steps.cache-stablehlo-build.outputs.cache-hit != 'true' run: | - C_COMPILER=$(which ${{ needs.constants.outputs[format('c_compiler.{0}', matrix.compiler)] }}) \ - CXX_COMPILER=$(which ${{ needs.constants.outputs[format('cxx_compiler.{0}', matrix.compiler)] }}) \ + C_COMPILER=$(which gcc) \ + CXX_COMPILER=$(which g++) \ LLVM_BUILD_DIR="$(pwd)/llvm-build" \ STABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ COMPILER_LAUNCHER="" \ diff --git a/.github/workflows/build-wheel-linux-x86_64.yaml b/.github/workflows/build-wheel-linux-x86_64.yaml index 352066d6e5..c67576d137 100644 --- a/.github/workflows/build-wheel-linux-x86_64.yaml +++ b/.github/workflows/build-wheel-linux-x86_64.yaml @@ -225,18 +225,12 @@ jobs: - name: Build Stablehlo Dialect if: steps.cache-stablehlo-build.outputs.cache-hit != 'true' run: | - export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH - cmake -S $GITHUB_WORKSPACE/mlir/stablehlo -B $GITHUB_WORKSPACE/stablehlo-build -G Ninja \ - -DCMAKE_BUILD_TYPE=Release \ - -DLLVM_ENABLE_ASSERTIONS=ON \ - -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ - -DSTABLEHLO_ENABLE_LLD=ON \ - -DSTABLEHLO_ENABLE_BINDINGS_PYTHON=OFF \ - -DSTABLEHLO_ENABLE_SPLIT_DWARF=ON \ - -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ - -DCMAKE_CXX_VISIBILITY_PRESET=default - - cmake --build $GITHUB_WORKSPACE/stablehlo-build + C_COMPILER=$(which gcc) \ + CXX_COMPILER=$(which g++) \ + LLVM_BUILD_DIR="$(pwd)/llvm-build" \ + STABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ + COMPILER_LAUNCHER="" \ + make stablehlo - name: Save Stablehlo Build id: save-stablehlo-build diff --git a/.github/workflows/build-wheel-macos-arm64.yaml b/.github/workflows/build-wheel-macos-arm64.yaml index 07743d018b..3c062f87a1 100644 --- a/.github/workflows/build-wheel-macos-arm64.yaml +++ b/.github/workflows/build-wheel-macos-arm64.yaml @@ -196,18 +196,10 @@ jobs: - name: Build Stablehlo Dialect if: steps.cache-stablehlo-build.outputs.cache-hit != 'true' run: | - export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH - cmake -S $GITHUB_WORKSPACE/mlir/stablehlo -B $GITHUB_WORKSPACE/stablehlo-build -G Ninja \ - -DCMAKE_BUILD_TYPE=Release \ - -DLLVM_ENABLE_ASSERTIONS=ON \ - -DMLIR_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/mlir" \ - -DSTABLEHLO_ENABLE_LLD=OFF \ - -DSTABLEHLO_ENABLE_BINDINGS_PYTHON=OFF \ - -DSTABLEHLO_ENABLE_SPLIT_DWARF=ON \ - -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ - -DCMAKE_CXX_VISIBILITY_PRESET=default - - cmake --build $GITHUB_WORKSPACE/stablehlo-build + LLVM_BUILD_DIR="$(pwd)/llvm-build" \ + STABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ + COMPILER_LAUNCHER="" \ + make stablehlo - name: Save Stablehlo Build id: save-stablehlo-build From dfb233c4c795b40da40be4ae944154996f625dc2 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 14 Aug 2025 17:54:42 -0400 Subject: [PATCH 62/63] clear stablehlo CI cache and rerun From ffca62e55e759127aa9e1bc81fb22807cc185a5b Mon Sep 17 00:00:00 2001 From: paul0403 Date: Thu, 14 Aug 2025 18:07:03 -0400 Subject: [PATCH 63/63] no lld for gcc --- .github/workflows/build-wheel-linux-arm64.yaml | 1 + .github/workflows/build-wheel-linux-x86_64.yaml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/build-wheel-linux-arm64.yaml b/.github/workflows/build-wheel-linux-arm64.yaml index 257eb59cf3..b79f6c164c 100644 --- a/.github/workflows/build-wheel-linux-arm64.yaml +++ b/.github/workflows/build-wheel-linux-arm64.yaml @@ -208,6 +208,7 @@ jobs: LLVM_BUILD_DIR="$(pwd)/llvm-build" \ STABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ COMPILER_LAUNCHER="" \ + ENABLE_LLD=OFF \ make stablehlo - name: Save Stablehlo Build diff --git a/.github/workflows/build-wheel-linux-x86_64.yaml b/.github/workflows/build-wheel-linux-x86_64.yaml index c67576d137..d6faa14248 100644 --- a/.github/workflows/build-wheel-linux-x86_64.yaml +++ b/.github/workflows/build-wheel-linux-x86_64.yaml @@ -230,6 +230,7 @@ jobs: LLVM_BUILD_DIR="$(pwd)/llvm-build" \ STABLEHLO_BUILD_DIR="$GITHUB_WORKSPACE/stablehlo-build" \ COMPILER_LAUNCHER="" \ + ENABLE_LLD=OFF \ make stablehlo - name: Save Stablehlo Build