Skip to content

Commit ebeceb8

Browse files
committed
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into probprog
2 parents a344726 + d8ffd0d commit ebeceb8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+2588
-589
lines changed

.github/workflows/CI-localjll.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ jobs:
4747
with:
4848
version: ${{ matrix.version }}
4949
- uses: julia-actions/cache@v2
50-
- uses: bazel-contrib/setup-bazel@0.14.0
50+
- uses: bazel-contrib/setup-bazel@0.15.0
5151
name: Set up Bazel
5252
with:
5353
# Avoid downloading Bazel every time.

Project.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.122"
4+
version = "0.2.134"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -52,7 +52,7 @@ ReactantArrayInterfaceExt = "ArrayInterface"
5252
ReactantCUDAExt = ["CUDA", "GPUCompiler", "KernelAbstractions", "LLVM"]
5353
ReactantKernelAbstractionsExt = "KernelAbstractions"
5454
ReactantMPIExt = "MPI"
55-
ReactantNNlibExt = "NNlib"
55+
ReactantNNlibExt = ["NNlib", "Statistics"]
5656
ReactantOffsetArraysExt = "OffsetArrays"
5757
ReactantOneHotArraysExt = "OneHotArrays"
5858
ReactantPythonCallExt = "PythonCall"
@@ -69,8 +69,8 @@ CEnum = "0.5"
6969
CUDA = "5.6"
7070
Downloads = "1.6"
7171
EnumX = "1"
72-
Enzyme = "0.13.46"
73-
EnzymeCore = "0.8.9"
72+
Enzyme = "0.13.49"
73+
EnzymeCore = "0.8.11"
7474
Functors = "0.5"
7575
GPUArraysCore = "0.2"
7676
GPUCompiler = "1.3"
@@ -89,8 +89,8 @@ Preferences = "1.4"
8989
PythonCall = "0.9"
9090
Random = "1.10"
9191
Random123 = "1.7"
92-
ReactantCore = "0.1.11"
93-
Reactant_jll = "0.0.191"
92+
ReactantCore = "0.1.12"
93+
Reactant_jll = "0.0.201"
9494
ScopedValues = "1.3.0"
9595
Scratch = "1.2"
9696
Sockets = "1.10"

deps/ReactantExtra/.bazelrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.8.1"
2323
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.8.0"
2424
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
2525
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
26-
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
26+
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,compute_90"
2727
build:cuda --crosstool_top="@local_config_cuda//crosstool:toolchain"
2828
build:cuda --@local_config_cuda//:enable_cuda
2929
# Default hermetic CUDA and CUDNN versions.

deps/ReactantExtra/API.cpp

Lines changed: 34 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,14 @@
3939
#include "src/enzyme_ad/jax/Dialect/Dialect.h"
4040
#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"
4141
#include "src/enzyme_ad/jax/Passes/Passes.h"
42+
#include "src/enzyme_ad/jax/RegistryUtils.h"
4243
#include "llvm/Support/TargetSelect.h"
4344

4445
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
4546
#include "stablehlo/dialect/ChloOps.h"
4647
#include "stablehlo/dialect/StablehloOps.h"
48+
#include "stablehlo/transforms/Passes.h"
49+
#include "stablehlo/transforms/optimization/Passes.h"
4750

4851
#include "absl/log/globals.h"
4952
#include "absl/log/initialize.h"
@@ -75,6 +78,7 @@
7578
#include "xla/pjrt/pjrt_api.h"
7679
#include "xla/pjrt/pjrt_c_api_client.h"
7780
#include "xla/pjrt/pjrt_executable.h"
81+
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
7882

7983
// CPU collectives
8084
#include "xla/backends/cpu/collectives/mpi_collectives.h"
@@ -147,10 +151,6 @@
147151

148152
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
149153

150-
// Triton did a dumb thing and their import is incompatible
151-
// We don't use so disabling until upstream fix
152-
// #include "triton/Dialect/Triton/IR/Dialect.h"
153-
154154
#include "llvm/Support/ExtensibleRTTI.h"
155155
#include <llvm/Support/FileSystem.h>
156156
#include <llvm/Support/raw_ostream.h>
@@ -164,10 +164,6 @@ void registerRemoveTransformPass();
164164
void registerGenerateApplyPatternsPass();
165165
} // namespace enzyme
166166

167-
namespace triton {
168-
class TritonDialect;
169-
}
170-
171167
} // namespace mlir
172168

173169
namespace reactant {
@@ -213,6 +209,8 @@ using HeldPjRtBuffer = HeldValue<std::shared_ptr<xla::PjRtBuffer>>;
213209
using HeldIfrtArray = HeldValue<tsl::RCReference<xla::ifrt::Array>>;
214210
using HeldHloModule = HeldValue<std::shared_ptr<xla::HloModule>>;
215211
using HeldIfrtSharding = HeldValue<std::shared_ptr<xla::ifrt::Sharding>>;
212+
using HeldIfrtLoadedExecutable =
213+
HeldValue<std::shared_ptr<xla::ifrt::LoadedExecutable>>;
216214

217215
extern "C" void (*ReactantThrowError)(const char *) = nullptr;
218216

@@ -409,8 +407,7 @@ PjRtClient *MakeCPUClientInternal(
409407
if (collectives.has_value())
410408
options.collectives = collectives.value();
411409

412-
auto client = MyValueOrThrow(GetTfrtCpuClient(options));
413-
return client.release();
410+
return MyValueOrThrow(GetPjRtCpuClient(options)).release();
414411
}
415412

416413
extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id) {
@@ -1118,25 +1115,15 @@ extern "C" int PjRtLoadedExecutableNumPartitions(PjRtLoadedExecutable *exec) {
11181115
return exec->num_partitions();
11191116
}
11201117

1121-
void prepareRegistry(mlir::DialectRegistry &registry);
1122-
11231118
extern "C" void RegisterDialects(MlirContext cctx) {
11241119
mlir::MLIRContext &context = *unwrap(cctx);
11251120
DialectRegistry registry;
1126-
prepareRegistry(registry);
1121+
mlir::enzyme::prepareRegistry(registry);
1122+
mlir::enzyme::registerDialects(registry);
1123+
mlir::enzyme::registerInterfaces(registry);
1124+
11271125
context.appendDialectRegistry(registry);
1128-
context.loadDialect<mlir::arith::ArithDialect>();
1129-
context.loadDialect<mlir::enzyme::EnzymeDialect>();
1130-
context.loadDialect<mlir::enzymexla::EnzymeXLADialect>();
1131-
// context.loadDialect<mlir::triton::TritonDialect>();
1132-
context.loadDialect<mlir::tpu::TPUDialect>();
1133-
context.loadDialect<mlir::tensor::TensorDialect>();
1134-
context.loadDialect<mlir::func::FuncDialect>();
1135-
context.loadDialect<mlir::mhlo::MhloDialect>();
1136-
context.loadDialect<mlir::stablehlo::StablehloDialect>();
1137-
context.loadDialect<mlir::chlo::ChloDialect>();
1138-
context.loadDialect<mlir::sdy::SdyDialect>();
1139-
context.loadDialect<mlir::LLVM::LLVMDialect>();
1126+
mlir::enzyme::loadAllRegisteredDialects(context);
11401127
}
11411128

11421129
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
@@ -1145,53 +1132,14 @@ extern "C" void RegisterDialects(MlirContext cctx) {
11451132
#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h"
11461133

11471134
extern "C" void InitializePasses(MlirDialectRegistry creg) {
1148-
mlir::registerenzymePasses();
1149-
enzyme::registerenzymexlaPasses();
1150-
1151-
// Register the standard passes we want.
1152-
mlir::registerTransformsPasses();
1153-
mlir::registerLowerAffinePass();
1154-
mlir::registerSCCPPass();
1155-
mlir::registerInlinerPass();
1156-
mlir::registerSymbolDCEPass();
1157-
mlir::registerLoopInvariantCodeMotionPass();
1158-
mlir::registerConvertSCFToOpenMPPass();
1159-
mlir::affine::registerAffinePasses();
1160-
mlir::registerReconcileUnrealizedCastsPass();
1161-
1162-
/*
1163-
registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
1164-
LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
1165-
LLVM::LLVMArrayType::attachInterface<MemRefInsider>(*ctx);
1166-
LLVM::LLVMPointerType::attachInterface<MemRefInsider>(*ctx);
1167-
LLVM::LLVMStructType::attachInterface<MemRefInsider>(*ctx);
1168-
MemRefType::attachInterface<PtrElementModel<MemRefType>>(*ctx);
1169-
LLVM::LLVMStructType::attachInterface<
1170-
PtrElementModel<LLVM::LLVMStructType>>(*ctx);
1171-
LLVM::LLVMPointerType::attachInterface<
1172-
PtrElementModel<LLVM::LLVMPointerType>>(*ctx);
1173-
LLVM::LLVMArrayType::attachInterface<PtrElementModel<LLVM::LLVMArrayType>>(
1174-
*ctx);
1175-
});
1176-
*/
1177-
1178-
// Transform dialect and extensions.
1179-
mlir::transform::registerInterpreterPass();
1180-
mlir::enzyme::registerGenerateApplyPatternsPass();
1181-
mlir::enzyme::registerRemoveTransformPass();
1182-
1183-
// xla + shardy specific passes
1184-
xla::sdy::registerSdyRoundTripExportPipeline();
1185-
xla::sdy::registerSdyRoundTripImportPipeline();
1186-
mlir::sdy::registerAllSdyPassesAndPipelines();
1187-
xla::sdy::registerStablehloExportPipeline();
1188-
xla::sdy::registerStablehloImportPipeline();
1189-
xla::sdy::registerStablehloImportShardingsPass();
1135+
mlir::enzyme::initializePasses();
11901136
}
11911137

11921138
extern "C" void InitializeRegistry(MlirDialectRegistry creg) {
11931139
mlir::DialectRegistry &registry = *unwrap(creg);
1194-
prepareRegistry(registry);
1140+
mlir::enzyme::prepareRegistry(registry);
1141+
mlir::enzyme::registerDialects(registry);
1142+
mlir::enzyme::registerInterfaces(registry);
11951143

11961144
mlir::registerLLVMDialectImport(registry);
11971145
mlir::registerNVVMDialectImport(registry);
@@ -1424,7 +1372,7 @@ ifrt_pjrt_array_create(ifrt::PjRtClient *client,
14241372

14251373
// we might me interested in the `Compiler::Compile` method variant that accepts
14261374
// `Topology`
1427-
extern "C" xla::ifrt::LoadedExecutable *
1375+
extern "C" HeldIfrtLoadedExecutable *
14281376
ifrt_compile(ifrt::Client *client, MlirModule cmod, int64_t device_id,
14291377
const int64_t *mesh_ids, int64_t num_mesh_ids,
14301378
const char *xla_gpu_cuda_data_dir, bool use_shardy_partitioner,
@@ -1452,9 +1400,8 @@ ifrt_compile(ifrt::Client *client, MlirModule cmod, int64_t device_id,
14521400
std::make_unique<xla::ifrt::HloProgram>(xla::ifrt::HloProgram(cmod_op));
14531401
auto compiler = client->GetDefaultCompiler();
14541402

1455-
return MyValueOrThrow(
1456-
compiler->Compile(std::move(program), std::move(options)))
1457-
.release();
1403+
return reactant::capture(MyValueOrThrow(
1404+
compiler->CompileAndLoad(std::move(program), std::move(options))));
14581405
}
14591406

14601407
extern "C" void
@@ -2325,19 +2272,20 @@ extern "C" mlir::sdy::TensorShardingAttr hloShardingToTensorShardingAttr(
23252272

23262273
return mlir::sdy::TensorShardingAttr::get(
23272274
context, meshName, tensorShardingAttr.getDimShardings(),
2328-
tensorShardingAttr.getReplicatedAxes());
2275+
tensorShardingAttr.getReplicatedAxes(),
2276+
tensorShardingAttr.getUnreducedAxes());
23292277
}
23302278

23312279
#pragma endregion
23322280

23332281
#pragma region ifrt::LoadedExecutable
23342282

2335-
extern "C" void ifrt_loaded_executable_dtor(ifrt::LoadedExecutable *exec) {
2283+
extern "C" void ifrt_loaded_executable_dtor(HeldIfrtLoadedExecutable *exec) {
23362284
delete exec;
23372285
}
23382286

23392287
extern "C" void ifrt_loaded_executable_execute(
2340-
ifrt::LoadedExecutable *exec, int num_args,
2288+
HeldIfrtLoadedExecutable *exec, int num_args,
23412289
HeldValue<tsl::RCReference<ifrt::Array>> **op_args,
23422290
uint8_t *is_arg_donatable, int num_results,
23432291
HeldValue<tsl::RCReference<ifrt::Array>> **op_results, uint8_t *futures,
@@ -2355,7 +2303,7 @@ extern "C" void ifrt_loaded_executable_execute(
23552303
}
23562304
options.fill_status = true;
23572305

2358-
auto result = MyValueOrThrow(exec->Execute(
2306+
auto result = MyValueOrThrow(exec->obj()->Execute(
23592307
static_cast<absl::Span<tsl::RCReference<xla::ifrt::Array>>>(args),
23602308
options, /* devices */ std::nullopt));
23612309

@@ -2376,16 +2324,16 @@ extern "C" void ifrt_loaded_executable_execute(
23762324
}
23772325

23782326
extern "C" ifrt::Client *
2379-
ifrt_loaded_executable_client(ifrt::LoadedExecutable *exec) {
2380-
return exec->client();
2327+
ifrt_loaded_executable_client(HeldIfrtLoadedExecutable *exec) {
2328+
return exec->obj()->client();
23812329
}
23822330

23832331
extern "C" void
2384-
ifrt_loaded_executable_get_parameter_shardings(ifrt::LoadedExecutable *exec,
2332+
ifrt_loaded_executable_get_parameter_shardings(HeldIfrtLoadedExecutable *exec,
23852333
xla::OpSharding **op_shardings,
23862334
int32_t num_op_shardings) {
23872335
std::optional<std::vector<xla::OpSharding>> shardings =
2388-
exec->GetParameterShardings();
2336+
exec->obj()->GetParameterShardings();
23892337
if (!shardings.has_value()) {
23902338
ReactantThrowError(
23912339
"No sharding found for the output of the loaded executable");
@@ -2405,11 +2353,11 @@ ifrt_loaded_executable_get_parameter_shardings(ifrt::LoadedExecutable *exec,
24052353
}
24062354

24072355
extern "C" void
2408-
ifrt_loaded_executable_get_output_shardings(ifrt::LoadedExecutable *exec,
2356+
ifrt_loaded_executable_get_output_shardings(HeldIfrtLoadedExecutable *exec,
24092357
xla::OpSharding **op_shardings,
24102358
int32_t num_op_shardings) {
24112359
std::optional<std::vector<xla::OpSharding>> shardings =
2412-
exec->GetOutputShardings();
2360+
exec->obj()->GetOutputShardings();
24132361
if (!shardings.has_value()) {
24142362
ReactantThrowError(
24152363
"No sharding found for the output of the loaded executable");
@@ -2429,18 +2377,18 @@ ifrt_loaded_executable_get_output_shardings(ifrt::LoadedExecutable *exec,
24292377
}
24302378

24312379
extern "C" void
2432-
ifrt_loaded_executable_get_hlo_modules(ifrt::LoadedExecutable *exec,
2380+
ifrt_loaded_executable_get_hlo_modules(HeldIfrtLoadedExecutable *exec,
24332381
void **hlo_modules, int32_t *nmodules) {
2434-
auto hlo_modules_vec = MyValueOrThrow(exec->GetHloModules());
2382+
auto hlo_modules_vec = MyValueOrThrow(exec->obj()->GetHloModules());
24352383
*nmodules = hlo_modules_vec.size();
24362384
for (int32_t i = 0; i < *nmodules; i++) {
24372385
hlo_modules[i] = reactant::capture(hlo_modules_vec[i]);
24382386
}
24392387
}
24402388

24412389
extern "C" int32_t
2442-
ifrt_loaded_executable_num_devices(ifrt::LoadedExecutable *exec) {
2443-
return static_cast<int32_t>(exec->num_devices());
2390+
ifrt_loaded_executable_num_devices(HeldIfrtLoadedExecutable *exec) {
2391+
return static_cast<int32_t>(exec->obj()->num_devices());
24442392
}
24452393

24462394
#pragma endregion

deps/ReactantExtra/BUILD

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -761,25 +761,18 @@ cc_library(
761761
],
762762

763763
) + [
764-
"@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp",
765764
"@enzyme_ad//src/enzyme_ad/jax:gpu.cc",
766765
"@enzyme_ad//src/enzyme_ad/jax:cpu.cc",
767-
# "@com_google_protobuf//:src/google/protobuf/io/coded_stream.cc",
768-
# "@xla//xla:xla.pb.cc",
769-
"@xla//xla:xla_data.pb.cc",
770-
# "@xla//xla/stream_executor:device_description.pb.cc",
771-
"@xla//xla/service:hlo.pb.cc",
772-
# # "@tsl//tsl/protobuf:dnn.pb.cc",
773-
#"@tsl//tsl/protobuf:histogram.pb.cc",
774-
#"@tsl//tsl/protobuf:bfc_memory_map_proto",bfc_memory_map.pb.cc",
775766
"@xla//xla/service/gpu:backend_configs.pb.cc",
776767
"@xla//xla:autotuning.pb.cc",
777768
"@xla//xla:autotune_results.pb.cc",
778769
"@xla//xla/service:buffer_assignment.pb.cc",
779770
],
780771
hdrs = glob([
781772
"*.h",
782-
]),
773+
]) + [
774+
"@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.h",
775+
],
783776
copts = [
784777
"-Werror=unused-variable",
785778
"-Werror=unused-but-set-variable",
@@ -901,6 +894,7 @@ cc_library(
901894
"-Wl,-exported_symbol,_addSdyPropagationPipeline",
902895
]}),
903896
deps = [
897+
"@enzyme_ad//src/enzyme_ad/jax:RegistryUtils",
904898
"@enzyme//:EnzymeMLIR",
905899
"@llvm-project//mlir:AffineDialect",
906900
"@llvm-project//mlir:AllPassesAndDialects",
@@ -1028,6 +1022,24 @@ cc_library(
10281022
"@xla//xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl",
10291023
"@xla//xla/service:gpu_plugin",
10301024
"@xla//xla/pjrt/c:pjrt_c_api_gpu",
1025+
"@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
1026+
"@xla//xla/pjrt/plugin/xla_gpu:xla_gpu_pjrt_client",
1027+
"@xla//xla/pjrt/plugin/xla_tpu:xla_tpu_pjrt_client",
1028+
"@stablehlo//:linalg_passes",
1029+
"@stablehlo//:tosa_passes",
1030+
"@stablehlo//:stablehlo_passes",
1031+
"@stablehlo//:stablehlo_passes_optimization",
1032+
"@stablehlo//stablehlo/tests:check_ops",
1033+
1034+
"@tsl//tsl/platform:env",
1035+
"@xla//xla/tsl/protobuf:dnn_proto_cc_impl",
1036+
"@xla//xla/tsl/protobuf:histogram_proto_cc",
1037+
"@xla//xla/tsl/protobuf:histogram_proto_cc_impl",
1038+
"@xla//xla:xla_data_proto_cc_impl",
1039+
"@xla//xla/tsl/platform:env",
1040+
"@xla//xla/tsl/platform:errors",
1041+
"@xla//xla/service:hlo_proto_cc_impl",
1042+
"@com_google_absl//absl/status:statusor",
10311043
] + select({
10321044
"@xla//xla/tsl:is_cuda_enabled_and_oss":[
10331045
"@xla//xla/stream_executor:cuda_platform",

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ http_archive(
99
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1010
)
1111

12-
ENZYMEXLA_COMMIT = "36230482acbbf9b32cca2e3906b7f08f4acc1e45"
12+
ENZYMEXLA_COMMIT = "7872c7761de86316c45a4df72210b1af38e09f07"
1313
ENZYMEXLA_SHA256 = ""
1414

1515
http_archive(

0 commit comments

Comments
 (0)