39
39
#include " src/enzyme_ad/jax/Dialect/Dialect.h"
40
40
#include " src/enzyme_ad/jax/Implementations/XLADerivatives.h"
41
41
#include " src/enzyme_ad/jax/Passes/Passes.h"
42
+ #include " src/enzyme_ad/jax/RegistryUtils.h"
42
43
#include " llvm/Support/TargetSelect.h"
43
44
44
45
#include " mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
45
46
#include " stablehlo/dialect/ChloOps.h"
46
47
#include " stablehlo/dialect/StablehloOps.h"
48
+ #include " stablehlo/transforms/Passes.h"
49
+ #include " stablehlo/transforms/optimization/Passes.h"
47
50
48
51
#include " absl/log/globals.h"
49
52
#include " absl/log/initialize.h"
75
78
#include " xla/pjrt/pjrt_api.h"
76
79
#include " xla/pjrt/pjrt_c_api_client.h"
77
80
#include " xla/pjrt/pjrt_executable.h"
81
+ #include " xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
78
82
79
83
// CPU collectives
80
84
#include " xla/backends/cpu/collectives/mpi_collectives.h"
147
151
148
152
#include " jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
149
153
150
- // Triton did a dumb thing and their import is incompatible
151
- // We don't use so disabling until upstream fix
152
- // #include "triton/Dialect/Triton/IR/Dialect.h"
153
-
154
154
#include " llvm/Support/ExtensibleRTTI.h"
155
155
#include < llvm/Support/FileSystem.h>
156
156
#include < llvm/Support/raw_ostream.h>
@@ -164,10 +164,6 @@ void registerRemoveTransformPass();
164
164
void registerGenerateApplyPatternsPass ();
165
165
} // namespace enzyme
166
166
167
- namespace triton {
168
- class TritonDialect ;
169
- }
170
-
171
167
} // namespace mlir
172
168
173
169
namespace reactant {
@@ -213,6 +209,8 @@ using HeldPjRtBuffer = HeldValue<std::shared_ptr<xla::PjRtBuffer>>;
213
209
using HeldIfrtArray = HeldValue<tsl::RCReference<xla::ifrt::Array>>;
214
210
using HeldHloModule = HeldValue<std::shared_ptr<xla::HloModule>>;
215
211
using HeldIfrtSharding = HeldValue<std::shared_ptr<xla::ifrt::Sharding>>;
212
+ using HeldIfrtLoadedExecutable =
213
+ HeldValue<std::shared_ptr<xla::ifrt::LoadedExecutable>>;
216
214
217
215
extern " C" void (*ReactantThrowError)(const char *) = nullptr ;
218
216
@@ -409,8 +407,7 @@ PjRtClient *MakeCPUClientInternal(
409
407
if (collectives.has_value ())
410
408
options.collectives = collectives.value ();
411
409
412
- auto client = MyValueOrThrow (GetTfrtCpuClient (options));
413
- return client.release ();
410
+ return MyValueOrThrow (GetPjRtCpuClient (options)).release ();
414
411
}
415
412
416
413
extern " C" PjRtClient *MakeCPUClient (uint8_t asynchronous, int node_id) {
@@ -1118,25 +1115,15 @@ extern "C" int PjRtLoadedExecutableNumPartitions(PjRtLoadedExecutable *exec) {
1118
1115
return exec->num_partitions ();
1119
1116
}
1120
1117
1121
- void prepareRegistry (mlir::DialectRegistry ®istry);
1122
-
1123
1118
// C entry point: populates a local DialectRegistry, appends it to the
// MLIRContext unwrapped from `cctx`, and then loads dialects into that context.
// The removed lines loaded a fixed list of dialects one at a time; the added
// lines delegate registration to mlir::enzyme::prepareRegistry /
// registerDialects / registerInterfaces and loading to
// mlir::enzyme::loadAllRegisteredDialects.
// NOTE(review): those helpers presumably come from RegistryUtils.h; confirm the
// set of dialects they load covers the explicit list removed here.
extern " C" void RegisterDialects (MlirContext cctx) {
1124
1119
mlir::MLIRContext &context = *unwrap (cctx);
1125
1120
DialectRegistry registry;
1126
- prepareRegistry (registry);
1121
+ mlir::enzyme::prepareRegistry (registry);
1122
+ mlir::enzyme::registerDialects (registry);
1123
+ mlir::enzyme::registerInterfaces (registry);
1124
+
1127
1125
context.appendDialectRegistry (registry);
1128
- context.loadDialect <mlir::arith::ArithDialect>();
1129
- context.loadDialect <mlir::enzyme::EnzymeDialect>();
1130
- context.loadDialect <mlir::enzymexla::EnzymeXLADialect>();
1131
- // context.loadDialect<mlir::triton::TritonDialect>();
1132
- context.loadDialect <mlir::tpu::TPUDialect>();
1133
- context.loadDialect <mlir::tensor::TensorDialect>();
1134
- context.loadDialect <mlir::func::FuncDialect>();
1135
- context.loadDialect <mlir::mhlo::MhloDialect>();
1136
- context.loadDialect <mlir::stablehlo::StablehloDialect>();
1137
- context.loadDialect <mlir::chlo::ChloDialect>();
1138
- context.loadDialect <mlir::sdy::SdyDialect>();
1139
- context.loadDialect <mlir::LLVM::LLVMDialect>();
1126
+ mlir::enzyme::loadAllRegisteredDialects (context);
1140
1127
}
1141
1128
1142
1129
#include " mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
@@ -1145,53 +1132,14 @@ extern "C" void RegisterDialects(MlirContext cctx) {
1145
1132
#include " xla/service/spmd/shardy/sdy_round_trip/pipelines.h"
1146
1133
1147
1134
// C entry point that registers the passes used by this library.
// The removed lines registered each pass and pipeline individually (enzyme,
// enzymexla, standard MLIR transforms, transform-dialect passes, and the
// xla/shardy pipelines); the added line replaces all of that with one call to
// mlir::enzyme::initializePasses().
// The `creg` parameter is not referenced by either version shown here.
// NOTE(review): confirm initializePasses() registers the same set as the list
// removed below.
extern " C" void InitializePasses (MlirDialectRegistry creg) {
1148
- mlir::registerenzymePasses ();
1149
- enzyme::registerenzymexlaPasses ();
1150
-
1151
- // Register the standard passes we want.
1152
- mlir::registerTransformsPasses ();
1153
- mlir::registerLowerAffinePass ();
1154
- mlir::registerSCCPPass ();
1155
- mlir::registerInlinerPass ();
1156
- mlir::registerSymbolDCEPass ();
1157
- mlir::registerLoopInvariantCodeMotionPass ();
1158
- mlir::registerConvertSCFToOpenMPPass ();
1159
- mlir::affine::registerAffinePasses ();
1160
- mlir::registerReconcileUnrealizedCastsPass ();
1161
-
1162
- /*
1163
- registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
1164
- LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
1165
- LLVM::LLVMArrayType::attachInterface<MemRefInsider>(*ctx);
1166
- LLVM::LLVMPointerType::attachInterface<MemRefInsider>(*ctx);
1167
- LLVM::LLVMStructType::attachInterface<MemRefInsider>(*ctx);
1168
- MemRefType::attachInterface<PtrElementModel<MemRefType>>(*ctx);
1169
- LLVM::LLVMStructType::attachInterface<
1170
- PtrElementModel<LLVM::LLVMStructType>>(*ctx);
1171
- LLVM::LLVMPointerType::attachInterface<
1172
- PtrElementModel<LLVM::LLVMPointerType>>(*ctx);
1173
- LLVM::LLVMArrayType::attachInterface<PtrElementModel<LLVM::LLVMArrayType>>(
1174
- *ctx);
1175
- });
1176
- */
1177
-
1178
- // Transform dialect and extensions.
1179
- mlir::transform::registerInterpreterPass ();
1180
- mlir::enzyme::registerGenerateApplyPatternsPass ();
1181
- mlir::enzyme::registerRemoveTransformPass ();
1182
-
1183
- // xla + shardy specific passes
1184
- xla::sdy::registerSdyRoundTripExportPipeline ();
1185
- xla::sdy::registerSdyRoundTripImportPipeline ();
1186
- mlir::sdy::registerAllSdyPassesAndPipelines ();
1187
- xla::sdy::registerStablehloExportPipeline ();
1188
- xla::sdy::registerStablehloImportPipeline ();
1189
- xla::sdy::registerStablehloImportShardingsPass ();
1135
+ mlir::enzyme::initializePasses ();
1190
1136
}
1191
1137
1192
1138
extern " C" void InitializeRegistry (MlirDialectRegistry creg) {
1193
1139
mlir::DialectRegistry ®istry = *unwrap (creg);
1194
- prepareRegistry (registry);
1140
+ mlir::enzyme::prepareRegistry (registry);
1141
+ mlir::enzyme::registerDialects (registry);
1142
+ mlir::enzyme::registerInterfaces (registry);
1195
1143
1196
1144
mlir::registerLLVMDialectImport (registry);
1197
1145
mlir::registerNVVMDialectImport (registry);
@@ -1424,7 +1372,7 @@ ifrt_pjrt_array_create(ifrt::PjRtClient *client,
1424
1372
1425
1373
// we might be interested in the `Compiler::Compile` method variant that accepts
1426
1374
// `Topology`
1427
- extern " C" xla::ifrt::LoadedExecutable *
1375
+ extern " C" HeldIfrtLoadedExecutable *
1428
1376
ifrt_compile (ifrt::Client *client, MlirModule cmod, int64_t device_id,
1429
1377
const int64_t *mesh_ids, int64_t num_mesh_ids,
1430
1378
const char *xla_gpu_cuda_data_dir, bool use_shardy_partitioner,
@@ -1452,9 +1400,8 @@ ifrt_compile(ifrt::Client *client, MlirModule cmod, int64_t device_id,
1452
1400
std::make_unique<xla::ifrt::HloProgram>(xla::ifrt::HloProgram (cmod_op));
1453
1401
auto compiler = client->GetDefaultCompiler ();
1454
1402
1455
- return MyValueOrThrow (
1456
- compiler->Compile (std::move (program), std::move (options)))
1457
- .release ();
1403
+ return reactant::capture (MyValueOrThrow (
1404
+ compiler->CompileAndLoad (std::move (program), std::move (options))));
1458
1405
}
1459
1406
1460
1407
extern " C" void
@@ -2325,19 +2272,20 @@ extern "C" mlir::sdy::TensorShardingAttr hloShardingToTensorShardingAttr(
2325
2272
2326
2273
return mlir::sdy::TensorShardingAttr::get (
2327
2274
context, meshName, tensorShardingAttr.getDimShardings (),
2328
- tensorShardingAttr.getReplicatedAxes ());
2275
+ tensorShardingAttr.getReplicatedAxes (),
2276
+ tensorShardingAttr.getUnreducedAxes ());
2329
2277
}
2330
2278
2331
2279
#pragma endregion
2332
2280
2333
2281
#pragma region ifrt::LoadedExecutable
2334
2282
2335
// C entry point that destroys a loaded-executable handle by calling `delete`
// on the pointer it receives.
// This change switches the handle type from a raw ifrt::LoadedExecutable* to a
// HeldIfrtLoadedExecutable* (a HeldValue around a std::shared_ptr), so the
// `delete` now applies to the holder object.
// NOTE(review): the underlying executable is presumably released only when the
// last shared_ptr reference goes away; confirm against HeldValue.
- extern " C" void ifrt_loaded_executable_dtor (ifrt::LoadedExecutable *exec) {
2283
+ extern " C" void ifrt_loaded_executable_dtor (HeldIfrtLoadedExecutable *exec) {
2336
2284
delete exec;
2337
2285
}
2338
2286
2339
2287
extern " C" void ifrt_loaded_executable_execute (
2340
- ifrt::LoadedExecutable *exec, int num_args,
2288
+ HeldIfrtLoadedExecutable *exec, int num_args,
2341
2289
HeldValue<tsl::RCReference<ifrt::Array>> **op_args,
2342
2290
uint8_t *is_arg_donatable, int num_results,
2343
2291
HeldValue<tsl::RCReference<ifrt::Array>> **op_results, uint8_t *futures,
@@ -2355,7 +2303,7 @@ extern "C" void ifrt_loaded_executable_execute(
2355
2303
}
2356
2304
options.fill_status = true ;
2357
2305
2358
- auto result = MyValueOrThrow (exec->Execute (
2306
+ auto result = MyValueOrThrow (exec->obj ()-> Execute (
2359
2307
static_cast <absl::Span<tsl::RCReference<xla::ifrt::Array>>>(args),
2360
2308
options, /* devices */ std::nullopt));
2361
2309
@@ -2376,16 +2324,16 @@ extern "C" void ifrt_loaded_executable_execute(
2376
2324
}
2377
2325
2378
2326
// C entry point returning the ifrt::Client reported by the loaded executable's
// client() accessor.
// The new version takes a HeldIfrtLoadedExecutable* and reaches the executable
// through the holder's obj() accessor instead of dereferencing `exec` directly.
// NOTE(review): the returned pointer looks non-owning (nothing is captured or
// released here); confirm callers do not free it.
extern " C" ifrt::Client *
2379
- ifrt_loaded_executable_client (ifrt::LoadedExecutable *exec) {
2380
- return exec->client ();
2327
+ ifrt_loaded_executable_client (HeldIfrtLoadedExecutable *exec) {
2328
+ return exec->obj ()-> client ();
2381
2329
}
2382
2330
2383
2331
extern " C" void
2384
- ifrt_loaded_executable_get_parameter_shardings (ifrt::LoadedExecutable *exec,
2332
+ ifrt_loaded_executable_get_parameter_shardings (HeldIfrtLoadedExecutable *exec,
2385
2333
xla::OpSharding **op_shardings,
2386
2334
int32_t num_op_shardings) {
2387
2335
std::optional<std::vector<xla::OpSharding>> shardings =
2388
- exec->GetParameterShardings ();
2336
+ exec->obj ()-> GetParameterShardings ();
2389
2337
if (!shardings.has_value ()) {
2390
2338
ReactantThrowError (
2391
2339
" No sharding found for the output of the loaded executable" );
@@ -2405,11 +2353,11 @@ ifrt_loaded_executable_get_parameter_shardings(ifrt::LoadedExecutable *exec,
2405
2353
}
2406
2354
2407
2355
extern " C" void
2408
- ifrt_loaded_executable_get_output_shardings (ifrt::LoadedExecutable *exec,
2356
+ ifrt_loaded_executable_get_output_shardings (HeldIfrtLoadedExecutable *exec,
2409
2357
xla::OpSharding **op_shardings,
2410
2358
int32_t num_op_shardings) {
2411
2359
std::optional<std::vector<xla::OpSharding>> shardings =
2412
- exec->GetOutputShardings ();
2360
+ exec->obj ()-> GetOutputShardings ();
2413
2361
if (!shardings.has_value ()) {
2414
2362
ReactantThrowError (
2415
2363
" No sharding found for the output of the loaded executable" );
@@ -2429,18 +2377,18 @@ ifrt_loaded_executable_get_output_shardings(ifrt::LoadedExecutable *exec,
2429
2377
}
2430
2378
2431
2379
// C entry point that exports the executable's HLO modules to the caller.
//   exec        - loaded-executable handle (a HeldIfrtLoadedExecutable* after
//                 this change, accessed via obj()).
//   hlo_modules - output array; slot i receives reactant::capture(module i).
//   nmodules    - output; set to the number of modules returned by
//                 GetHloModules().
// The result of GetHloModules() goes through MyValueOrThrow before use.
// NOTE(review): no bounds check is visible here; this assumes the caller sized
// `hlo_modules` to hold at least *nmodules entries. Confirm at the call site.
extern " C" void
2432
- ifrt_loaded_executable_get_hlo_modules (ifrt::LoadedExecutable *exec,
2380
+ ifrt_loaded_executable_get_hlo_modules (HeldIfrtLoadedExecutable *exec,
2433
2381
void **hlo_modules, int32_t *nmodules) {
2434
- auto hlo_modules_vec = MyValueOrThrow (exec->GetHloModules ());
2382
+ auto hlo_modules_vec = MyValueOrThrow (exec->obj ()-> GetHloModules ());
2435
2383
*nmodules = hlo_modules_vec.size ();
2436
2384
for (int32_t i = 0 ; i < *nmodules; i++) {
2437
2385
hlo_modules[i] = reactant::capture (hlo_modules_vec[i]);
2438
2386
}
2439
2387
}
2440
2388
2441
2389
// C entry point returning the executable's num_devices() value, narrowed to
// int32_t with static_cast.
// The new version takes a HeldIfrtLoadedExecutable* and calls num_devices()
// through the holder's obj() accessor.
extern " C" int32_t
2442
- ifrt_loaded_executable_num_devices (ifrt::LoadedExecutable *exec) {
2443
- return static_cast <int32_t >(exec->num_devices ());
2390
+ ifrt_loaded_executable_num_devices (HeldIfrtLoadedExecutable *exec) {
2391
+ return static_cast <int32_t >(exec->obj ()-> num_devices ());
2444
2392
}
2445
2393
2446
2394
#pragma endregion
0 commit comments