Skip to content

Commit f716f6b

Browse files
authored
feat: register various stablehlo passes (#1391)
* feat: register various stablehlo passes * feat: centralize dialect registration in EnzymeJAX * fix: update passes * chore: bump commit * fix: build * fix: linking errors * chore: bump commit
1 parent b25d1d3 commit f716f6b

File tree

3 files changed

+34
-78
lines changed

3 files changed

+34
-78
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 14 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,14 @@
3939
#include "src/enzyme_ad/jax/Dialect/Dialect.h"
4040
#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"
4141
#include "src/enzyme_ad/jax/Passes/Passes.h"
42+
#include "src/enzyme_ad/jax/RegistryUtils.h"
4243
#include "llvm/Support/TargetSelect.h"
4344

4445
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
4546
#include "stablehlo/dialect/ChloOps.h"
4647
#include "stablehlo/dialect/StablehloOps.h"
48+
#include "stablehlo/transforms/Passes.h"
49+
#include "stablehlo/transforms/optimization/Passes.h"
4750

4851
#include "absl/log/globals.h"
4952
#include "absl/log/initialize.h"
@@ -148,10 +151,6 @@
148151

149152
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
150153

151-
// Triton did a dumb thing and their import is incompatible
152-
// We don't use so disabling until upstream fix
153-
// #include "triton/Dialect/Triton/IR/Dialect.h"
154-
155154
#include "llvm/Support/ExtensibleRTTI.h"
156155
#include <llvm/Support/FileSystem.h>
157156
#include <llvm/Support/raw_ostream.h>
@@ -165,10 +164,6 @@ void registerRemoveTransformPass();
165164
void registerGenerateApplyPatternsPass();
166165
} // namespace enzyme
167166

168-
namespace triton {
169-
class TritonDialect;
170-
}
171-
172167
} // namespace mlir
173168

174169
namespace reactant {
@@ -1120,25 +1115,15 @@ extern "C" int PjRtLoadedExecutableNumPartitions(PjRtLoadedExecutable *exec) {
11201115
return exec->num_partitions();
11211116
}
11221117

1123-
void prepareRegistry(mlir::DialectRegistry &registry);
1124-
11251118
extern "C" void RegisterDialects(MlirContext cctx) {
11261119
mlir::MLIRContext &context = *unwrap(cctx);
11271120
DialectRegistry registry;
1128-
prepareRegistry(registry);
1121+
mlir::enzyme::prepareRegistry(registry);
1122+
mlir::enzyme::registerDialects(registry);
1123+
mlir::enzyme::registerInterfaces(registry);
1124+
11291125
context.appendDialectRegistry(registry);
1130-
context.loadDialect<mlir::arith::ArithDialect>();
1131-
context.loadDialect<mlir::enzyme::EnzymeDialect>();
1132-
context.loadDialect<mlir::enzymexla::EnzymeXLADialect>();
1133-
// context.loadDialect<mlir::triton::TritonDialect>();
1134-
context.loadDialect<mlir::tpu::TPUDialect>();
1135-
context.loadDialect<mlir::tensor::TensorDialect>();
1136-
context.loadDialect<mlir::func::FuncDialect>();
1137-
context.loadDialect<mlir::mhlo::MhloDialect>();
1138-
context.loadDialect<mlir::stablehlo::StablehloDialect>();
1139-
context.loadDialect<mlir::chlo::ChloDialect>();
1140-
context.loadDialect<mlir::sdy::SdyDialect>();
1141-
context.loadDialect<mlir::LLVM::LLVMDialect>();
1126+
mlir::enzyme::loadAllRegisteredDialects(context);
11421127
}
11431128

11441129
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
@@ -1147,53 +1132,14 @@ extern "C" void RegisterDialects(MlirContext cctx) {
11471132
#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h"
11481133

11491134
extern "C" void InitializePasses(MlirDialectRegistry creg) {
1150-
mlir::registerenzymePasses();
1151-
enzyme::registerenzymexlaPasses();
1152-
1153-
// Register the standard passes we want.
1154-
mlir::registerTransformsPasses();
1155-
mlir::registerLowerAffinePass();
1156-
mlir::registerSCCPPass();
1157-
mlir::registerInlinerPass();
1158-
mlir::registerSymbolDCEPass();
1159-
mlir::registerLoopInvariantCodeMotionPass();
1160-
mlir::registerConvertSCFToOpenMPPass();
1161-
mlir::affine::registerAffinePasses();
1162-
mlir::registerReconcileUnrealizedCastsPass();
1163-
1164-
/*
1165-
registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
1166-
LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
1167-
LLVM::LLVMArrayType::attachInterface<MemRefInsider>(*ctx);
1168-
LLVM::LLVMPointerType::attachInterface<MemRefInsider>(*ctx);
1169-
LLVM::LLVMStructType::attachInterface<MemRefInsider>(*ctx);
1170-
MemRefType::attachInterface<PtrElementModel<MemRefType>>(*ctx);
1171-
LLVM::LLVMStructType::attachInterface<
1172-
PtrElementModel<LLVM::LLVMStructType>>(*ctx);
1173-
LLVM::LLVMPointerType::attachInterface<
1174-
PtrElementModel<LLVM::LLVMPointerType>>(*ctx);
1175-
LLVM::LLVMArrayType::attachInterface<PtrElementModel<LLVM::LLVMArrayType>>(
1176-
*ctx);
1177-
});
1178-
*/
1179-
1180-
// Transform dialect and extensions.
1181-
mlir::transform::registerInterpreterPass();
1182-
mlir::enzyme::registerGenerateApplyPatternsPass();
1183-
mlir::enzyme::registerRemoveTransformPass();
1184-
1185-
// xla + shardy specific passes
1186-
xla::sdy::registerSdyRoundTripExportPipeline();
1187-
xla::sdy::registerSdyRoundTripImportPipeline();
1188-
mlir::sdy::registerAllSdyPassesAndPipelines();
1189-
xla::sdy::registerStablehloExportPipeline();
1190-
xla::sdy::registerStablehloImportPipeline();
1191-
xla::sdy::registerStablehloImportShardingsPass();
1135+
mlir::enzyme::initializePasses();
11921136
}
11931137

11941138
extern "C" void InitializeRegistry(MlirDialectRegistry creg) {
11951139
mlir::DialectRegistry &registry = *unwrap(creg);
1196-
prepareRegistry(registry);
1140+
mlir::enzyme::prepareRegistry(registry);
1141+
mlir::enzyme::registerDialects(registry);
1142+
mlir::enzyme::registerInterfaces(registry);
11971143

11981144
mlir::registerLLVMDialectImport(registry);
11991145
mlir::registerNVVMDialectImport(registry);
@@ -2326,7 +2272,8 @@ extern "C" mlir::sdy::TensorShardingAttr hloShardingToTensorShardingAttr(
23262272

23272273
return mlir::sdy::TensorShardingAttr::get(
23282274
context, meshName, tensorShardingAttr.getDimShardings(),
2329-
tensorShardingAttr.getReplicatedAxes(), tensorShardingAttr.getUnreducedAxes());
2275+
tensorShardingAttr.getReplicatedAxes(),
2276+
tensorShardingAttr.getUnreducedAxes());
23302277
}
23312278

23322279
#pragma endregion

deps/ReactantExtra/BUILD

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -761,25 +761,18 @@ cc_library(
761761
],
762762

763763
) + [
764-
"@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp",
765764
"@enzyme_ad//src/enzyme_ad/jax:gpu.cc",
766765
"@enzyme_ad//src/enzyme_ad/jax:cpu.cc",
767-
# "@com_google_protobuf//:src/google/protobuf/io/coded_stream.cc",
768-
# "@xla//xla:xla.pb.cc",
769-
"@xla//xla:xla_data.pb.cc",
770-
# "@xla//xla/stream_executor:device_description.pb.cc",
771-
"@xla//xla/service:hlo.pb.cc",
772-
# # "@tsl//tsl/protobuf:dnn.pb.cc",
773-
#"@tsl//tsl/protobuf:histogram.pb.cc",
774-
#"@tsl//tsl/protobuf:bfc_memory_map_proto",bfc_memory_map.pb.cc",
775766
"@xla//xla/service/gpu:backend_configs.pb.cc",
776767
"@xla//xla:autotuning.pb.cc",
777768
"@xla//xla:autotune_results.pb.cc",
778769
"@xla//xla/service:buffer_assignment.pb.cc",
779770
],
780771
hdrs = glob([
781772
"*.h",
782-
]),
773+
]) + [
774+
"@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.h",
775+
],
783776
copts = [
784777
"-Werror=unused-variable",
785778
"-Werror=unused-but-set-variable",
@@ -901,6 +894,7 @@ cc_library(
901894
"-Wl,-exported_symbol,_addSdyPropagationPipeline",
902895
]}),
903896
deps = [
897+
"@enzyme_ad//src/enzyme_ad/jax:RegistryUtils",
904898
"@enzyme//:EnzymeMLIR",
905899
"@llvm-project//mlir:AffineDialect",
906900
"@llvm-project//mlir:AllPassesAndDialects",
@@ -1031,6 +1025,21 @@ cc_library(
10311025
"@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
10321026
"@xla//xla/pjrt/plugin/xla_gpu:xla_gpu_pjrt_client",
10331027
"@xla//xla/pjrt/plugin/xla_tpu:xla_tpu_pjrt_client",
1028+
"@stablehlo//:linalg_passes",
1029+
"@stablehlo//:tosa_passes",
1030+
"@stablehlo//:stablehlo_passes",
1031+
"@stablehlo//:stablehlo_passes_optimization",
1032+
"@stablehlo//stablehlo/tests:check_ops",
1033+
1034+
"@tsl//tsl/platform:env",
1035+
"@xla//xla/tsl/protobuf:dnn_proto_cc_impl",
1036+
"@xla//xla/tsl/protobuf:histogram_proto_cc",
1037+
"@xla//xla/tsl/protobuf:histogram_proto_cc_impl",
1038+
"@xla//xla:xla_data_proto_cc_impl",
1039+
"@xla//xla/tsl/platform:env",
1040+
"@xla//xla/tsl/platform:errors",
1041+
"@xla//xla/service:hlo_proto_cc_impl",
1042+
"@com_google_absl//absl/status:statusor",
10341043
] + select({
10351044
"@xla//xla/tsl:is_cuda_enabled_and_oss":[
10361045
"@xla//xla/stream_executor:cuda_platform",

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ http_archive(
99
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1010
)
1111

12-
ENZYMEXLA_COMMIT = "bc7a3e5be1e08984329bd16753b07346eda07996"
12+
ENZYMEXLA_COMMIT = "4a19bc9910f34f4afcd2c2559c8776f7ca86ae4d"
1313
ENZYMEXLA_SHA256 = ""
1414

1515
http_archive(

0 commit comments

Comments
 (0)