39
39
#include " src/enzyme_ad/jax/Dialect/Dialect.h"
40
40
#include " src/enzyme_ad/jax/Implementations/XLADerivatives.h"
41
41
#include " src/enzyme_ad/jax/Passes/Passes.h"
42
+ #include " src/enzyme_ad/jax/RegistryUtils.h"
42
43
#include " llvm/Support/TargetSelect.h"
43
44
44
45
#include " mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
45
46
#include " stablehlo/dialect/ChloOps.h"
46
47
#include " stablehlo/dialect/StablehloOps.h"
48
+ #include " stablehlo/transforms/Passes.h"
49
+ #include " stablehlo/transforms/optimization/Passes.h"
47
50
48
51
#include " absl/log/globals.h"
49
52
#include " absl/log/initialize.h"
148
151
149
152
#include " jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
150
153
151
- // Triton did a dumb thing and their import is incompatible
152
- // We don't use so disabling until upstream fix
153
- // #include "triton/Dialect/Triton/IR/Dialect.h"
154
-
155
154
#include " llvm/Support/ExtensibleRTTI.h"
156
155
#include < llvm/Support/FileSystem.h>
157
156
#include < llvm/Support/raw_ostream.h>
@@ -165,10 +164,6 @@ void registerRemoveTransformPass();
165
164
void registerGenerateApplyPatternsPass ();
166
165
} // namespace enzyme
167
166
168
- namespace triton {
169
- class TritonDialect ;
170
- }
171
-
172
167
} // namespace mlir
173
168
174
169
namespace reactant {
@@ -1120,25 +1115,15 @@ extern "C" int PjRtLoadedExecutableNumPartitions(PjRtLoadedExecutable *exec) {
1120
1115
return exec->num_partitions ();
1121
1116
}
1122
1117
1123
- void prepareRegistry (mlir::DialectRegistry ®istry);
1124
-
1125
1118
extern " C" void RegisterDialects (MlirContext cctx) {
1126
1119
mlir::MLIRContext &context = *unwrap (cctx);
1127
1120
DialectRegistry registry;
1128
- prepareRegistry (registry);
1121
+ mlir::enzyme::prepareRegistry (registry);
1122
+ mlir::enzyme::registerDialects (registry);
1123
+ mlir::enzyme::registerInterfaces (registry);
1124
+
1129
1125
context.appendDialectRegistry (registry);
1130
- context.loadDialect <mlir::arith::ArithDialect>();
1131
- context.loadDialect <mlir::enzyme::EnzymeDialect>();
1132
- context.loadDialect <mlir::enzymexla::EnzymeXLADialect>();
1133
- // context.loadDialect<mlir::triton::TritonDialect>();
1134
- context.loadDialect <mlir::tpu::TPUDialect>();
1135
- context.loadDialect <mlir::tensor::TensorDialect>();
1136
- context.loadDialect <mlir::func::FuncDialect>();
1137
- context.loadDialect <mlir::mhlo::MhloDialect>();
1138
- context.loadDialect <mlir::stablehlo::StablehloDialect>();
1139
- context.loadDialect <mlir::chlo::ChloDialect>();
1140
- context.loadDialect <mlir::sdy::SdyDialect>();
1141
- context.loadDialect <mlir::LLVM::LLVMDialect>();
1126
+ mlir::enzyme::loadAllRegisteredDialects (context);
1142
1127
}
1143
1128
1144
1129
#include " mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
@@ -1147,53 +1132,14 @@ extern "C" void RegisterDialects(MlirContext cctx) {
1147
1132
#include " xla/service/spmd/shardy/sdy_round_trip/pipelines.h"
1148
1133
1149
1134
extern " C" void InitializePasses (MlirDialectRegistry creg) {
1150
- mlir::registerenzymePasses ();
1151
- enzyme::registerenzymexlaPasses ();
1152
-
1153
- // Register the standard passes we want.
1154
- mlir::registerTransformsPasses ();
1155
- mlir::registerLowerAffinePass ();
1156
- mlir::registerSCCPPass ();
1157
- mlir::registerInlinerPass ();
1158
- mlir::registerSymbolDCEPass ();
1159
- mlir::registerLoopInvariantCodeMotionPass ();
1160
- mlir::registerConvertSCFToOpenMPPass ();
1161
- mlir::affine::registerAffinePasses ();
1162
- mlir::registerReconcileUnrealizedCastsPass ();
1163
-
1164
- /*
1165
- registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
1166
- LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
1167
- LLVM::LLVMArrayType::attachInterface<MemRefInsider>(*ctx);
1168
- LLVM::LLVMPointerType::attachInterface<MemRefInsider>(*ctx);
1169
- LLVM::LLVMStructType::attachInterface<MemRefInsider>(*ctx);
1170
- MemRefType::attachInterface<PtrElementModel<MemRefType>>(*ctx);
1171
- LLVM::LLVMStructType::attachInterface<
1172
- PtrElementModel<LLVM::LLVMStructType>>(*ctx);
1173
- LLVM::LLVMPointerType::attachInterface<
1174
- PtrElementModel<LLVM::LLVMPointerType>>(*ctx);
1175
- LLVM::LLVMArrayType::attachInterface<PtrElementModel<LLVM::LLVMArrayType>>(
1176
- *ctx);
1177
- });
1178
- */
1179
-
1180
- // Transform dialect and extensions.
1181
- mlir::transform::registerInterpreterPass ();
1182
- mlir::enzyme::registerGenerateApplyPatternsPass ();
1183
- mlir::enzyme::registerRemoveTransformPass ();
1184
-
1185
- // xla + shardy specific passes
1186
- xla::sdy::registerSdyRoundTripExportPipeline ();
1187
- xla::sdy::registerSdyRoundTripImportPipeline ();
1188
- mlir::sdy::registerAllSdyPassesAndPipelines ();
1189
- xla::sdy::registerStablehloExportPipeline ();
1190
- xla::sdy::registerStablehloImportPipeline ();
1191
- xla::sdy::registerStablehloImportShardingsPass ();
1135
+ mlir::enzyme::initializePasses ();
1192
1136
}
1193
1137
1194
1138
extern " C" void InitializeRegistry (MlirDialectRegistry creg) {
1195
1139
mlir::DialectRegistry ®istry = *unwrap (creg);
1196
- prepareRegistry (registry);
1140
+ mlir::enzyme::prepareRegistry (registry);
1141
+ mlir::enzyme::registerDialects (registry);
1142
+ mlir::enzyme::registerInterfaces (registry);
1197
1143
1198
1144
mlir::registerLLVMDialectImport (registry);
1199
1145
mlir::registerNVVMDialectImport (registry);
@@ -2326,7 +2272,8 @@ extern "C" mlir::sdy::TensorShardingAttr hloShardingToTensorShardingAttr(
2326
2272
2327
2273
return mlir::sdy::TensorShardingAttr::get (
2328
2274
context, meshName, tensorShardingAttr.getDimShardings (),
2329
- tensorShardingAttr.getReplicatedAxes (), tensorShardingAttr.getUnreducedAxes ());
2275
+ tensorShardingAttr.getReplicatedAxes (),
2276
+ tensorShardingAttr.getUnreducedAxes ());
2330
2277
}
2331
2278
2332
2279
#pragma endregion
0 commit comments