Skip to content

Commit 1a23c2e

Browse files
committed
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into probprog-trace-operand
2 parents 0b71444 + da9c8a3 commit 1a23c2e

33 files changed

+977
-242
lines changed

.github/workflows/CI.yml

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ jobs:
3333
timeout-minutes: 90
3434
name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} - ${{ matrix.runtime }} - assertions=${{ matrix.assertions }} - ${{ github.event_name }}
3535
runs-on: ${{ matrix.os }}
36+
container:
37+
image: ${{ contains(matrix.os, 'linux') && 'ghcr.io/enzymead/reactant-docker-images:main' || '' }}
3638
strategy:
3739
fail-fast: false
3840
matrix:
@@ -59,6 +61,11 @@ jobs:
5961
assertions:
6062
- false
6163
include:
64+
- os: linux-x86-ct6e-180-4tpu
65+
version: "1.11"
66+
assertions: false
67+
test_group: core
68+
runtime: "IFRT"
6269
- os: ubuntu-24.04
6370
version: "1.10"
6471
assertions: true
@@ -86,9 +93,13 @@ jobs:
8693
# libReactant: packaged
8794
# version: '1.10'
8895
# test_group: integration
89-
env:
90-
TMPDIR: ${{ github.workspace }}/tmp
9196
steps:
97+
- name: Set TMPDIR
98+
# Use the `${GITHUB_WORKSPACE}` environment variable instead of the `github.workspace`
99+
# context value: inside container jobs the two do not match, see
100+
# https://github.com/actions/runner/issues/2058
101+
run: |
102+
echo "TMPDIR=${GITHUB_WORKSPACE}/tmp" >> ${GITHUB_ENV}
92103
- uses: actions/checkout@v4
93104
- name: Create TMPDIR
94105
run: |

Project.toml

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.146"
4+
version = "0.2.149"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -25,11 +25,14 @@ Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
2525
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
2626
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
2727
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
28+
unzip_jll = "88f77b66-78eb-5ed0-bc16-ebba0796830d"
2829

2930
[weakdeps]
3031
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
3132
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3233
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
34+
DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c"
35+
Float8s = "81dfefd7-55b0-40c6-a251-db853704e186"
3336
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
3437
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
3538
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
@@ -43,13 +46,15 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
4346
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
4447
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
4548

46-
[sources.ReactantCore]
47-
path = "lib/ReactantCore"
49+
[sources]
50+
ReactantCore = {path = "lib/ReactantCore"}
4851

4952
[extensions]
5053
ReactantAbstractFFTsExt = "AbstractFFTs"
5154
ReactantArrayInterfaceExt = "ArrayInterface"
5255
ReactantCUDAExt = ["CUDA", "GPUCompiler", "KernelAbstractions", "LLVM"]
56+
ReactantDLFP8TypesExt = "DLFP8Types"
57+
ReactantFloat8sExt = "Float8s"
5358
ReactantKernelAbstractionsExt = "KernelAbstractions"
5459
ReactantMPIExt = "MPI"
5560
ReactantNNlibExt = ["NNlib", "Statistics"]
@@ -67,10 +72,12 @@ Adapt = "4.1"
6772
ArrayInterface = "7.17.1"
6873
CEnum = "0.5"
6974
CUDA = "5.6"
75+
DLFP8Types = "0.1"
7076
Downloads = "1.6"
7177
EnumX = "1"
7278
Enzyme = "0.13.49"
7379
EnzymeCore = "0.8.11"
80+
Float8s = "0.1"
7481
Functors = "0.5"
7582
GPUArraysCore = "0.2"
7683
GPUCompiler = "1.3"
@@ -90,12 +97,13 @@ PythonCall = "0.9.25"
9097
Random = "1.10"
9198
Random123 = "1.7"
9299
ReactantCore = "0.1.15"
93-
Reactant_jll = "0.0.219"
100+
Reactant_jll = "0.0.224"
94101
ScopedValues = "1.3.0"
95102
Scratch = "1.2"
96103
Sockets = "1.10"
97104
SpecialFunctions = "2.4"
98105
Statistics = "1.10"
106+
unzip_jll = "6"
99107
YaoBlocks = "0.13, 0.14"
100108
julia = "1.10"
101109

deps/ReactantExtra/.bazelrc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ build --announce_rc
22

33
# TODO: Migrate for https://github.com/bazelbuild/bazel/issues/7260
44
common --noincompatible_enable_cc_toolchain_resolution
5+
common --repo_env USE_HERMETIC_CC_TOOLCHAIN=0
56
common --experimental_repo_remote_exec
67
common --cxxopt=-std=c++17 --host_cxxopt=-std=c++17
78
common --cxxopt=-w --host_cxxopt=-w
89
common --define=grpc_no_ares=true
910
common --noenable_bzlmod
1011

12+
1113
build --repo_env=USE_PYWRAP_RULES=True
1214
build --copt=-DGRPC_BAZEL_BUILD
1315
build --host_copt=-DGRPC_BAZEL_BUILD
@@ -27,6 +29,7 @@ build:cuda --repo_env TF_NVCC_CLANG=1
2729
build:cuda --repo_env TF_NCCL_USE_STUB=1
2830
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.8.1"
2931
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.8.0"
32+
build:cuda --repo_env=HERMETIC_NVSHMEM_VERSION="3.2.5"
3033
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
3134
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
3235
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,compute_90"
@@ -35,6 +38,8 @@ build:cuda --@local_config_cuda//:enable_cuda
3538
# Default hermetic CUDA and CUDNN versions.
3639
build:cuda --@local_config_cuda//cuda:include_cuda_libs=true
3740
build:cuda --@local_config_cuda//:cuda_compiler=nvcc
41+
# build:cuda --@local_config_nvshmem//:override_include_nvshmem_libs=true
42+
# build:cuda --@local_config_nvshmem//cuda:include_nvshmem_libs=true
3843

3944
build:rocm --repo_env TF_NEED_ROCM=1
4045
build:rocm --define=using_rocm=true

deps/ReactantExtra/BUILD

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
22
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
3+
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
34
load("@xla//tools/toolchains/cross_compile/cc:cc_toolchain_config.bzl", "cc_toolchain_config")
45

56
# load("//toolchain:yggdrasil.bzl", "ygg_cc_toolchain")
@@ -752,6 +753,22 @@ platform(
752753
],
753754
)
754755

756+
# Bazel platform "win_x86_64": pairs the x86_64 CPU constraint with the OS constraint listed below.
platform(
757+
name = "win_x86_64",
758+
constraint_values = [
759+
# NOTE(review): the platform is named win_* but the OS constraint is Linux, so
# select() branches keyed on @platforms//os:windows will not match it. Confirm whether
# this is deliberate (e.g. a cross-compilation workaround) or a copy-paste slip for
# "@platforms//os:windows".
"@platforms//os:linux",
760+
"@platforms//cpu:x86_64",
761+
],
762+
)
763+
764+
# Bazel platform "win_aarch64": pairs the aarch64 CPU constraint with the OS constraint listed below.
platform(
765+
name = "win_aarch64",
766+
constraint_values = [
767+
# NOTE(review): the platform is named win_* but the OS constraint is Linux, so
# select() branches keyed on @platforms//os:windows will not match it. Confirm whether
# this is deliberate (e.g. a cross-compilation workaround) or a copy-paste slip for
# "@platforms//os:windows".
"@platforms//os:linux",
768+
"@platforms//cpu:aarch64",
769+
],
770+
)
771+
755772
cc_library(
756773
name = "ReactantExtraLib",
757774
srcs = glob(
@@ -777,12 +794,7 @@ cc_library(
777794
"-Werror=return-type",
778795
"-Werror=unused-result",
779796
"-Wno-error=stringop-truncation",
780-
] + select({
781-
"@xla//xla/tsl:is_cuda_enabled_and_oss": [
782-
"-DREACTANT_CUDA=1",
783-
],
784-
"//conditions:default": [],
785-
}),
797+
] + if_cuda(["-DREACTANT_CUDA=1"]),
786798
linkopts = select({
787799
"//conditions:default": [],
788800
"@bazel_tools//src/conditions:darwin": [
@@ -795,6 +807,9 @@ cc_library(
795807
"-Wl,-exported_symbol,_SetModuleLogLevel",
796808
"-Wl,-exported_symbol,_GetDefaultTargetTriple",
797809
"-Wl,-exported_symbol,_enzymeActivityAttrGet",
810+
"-Wl,-exported_symbol,_UninitPJRTBuffer",
811+
"-Wl,-exported_symbol,_CopyToBuffer",
812+
"-Wl,-exported_symbol,_CopyFromBuffer",
798813
"-Wl,-exported_symbol,_MakeCPUClient",
799814
"-Wl,-exported_symbol,_MakeGPUClient",
800815
"-Wl,-exported_symbol,_MakeTPUClient",
@@ -1029,23 +1044,19 @@ cc_library(
10291044
"@xla//xla/tsl/platform:errors",
10301045
"@xla//xla/service:hlo_proto_cc_impl",
10311046
"@com_google_absl//absl/status:statusor",
1032-
] + select({
1033-
"@xla//xla/tsl:is_cuda_enabled_and_oss": [
1034-
"@jax//jaxlib/cuda:cuda_gpu_kernels",
1035-
"@xla//xla/backends/profiler:profiler_backends",
1036-
"@xla//xla/backends/profiler/gpu:device_tracer",
1037-
"@xla//xla/pjrt/c:pjrt_c_api_gpu_internal",
1038-
"@xla//xla/service/gpu:gpu_transfer_manager",
1039-
"@xla//xla/service/gpu:nvptx_compiler",
1040-
"@xla//xla/service/gpu/model:hlo_op_profile_proto_cc_impl",
1041-
"@xla//xla/service/gpu/model:hlo_op_profiles",
1042-
"@xla//xla/stream_executor:cuda_platform",
1043-
"@xla//xla/stream_executor:kernel",
1044-
"@xla//xla/stream_executor/cuda:all_runtime",
1045-
],
1046-
"//conditions:default": [
1047-
],
1048-
}) + if_rocm([
1047+
] + if_cuda([
1048+
"@jax//jaxlib/cuda:cuda_gpu_kernels",
1049+
"@xla//xla/backends/profiler:profiler_backends",
1050+
"@xla//xla/backends/profiler/gpu:device_tracer",
1051+
"@xla//xla/pjrt/c:pjrt_c_api_gpu_internal",
1052+
"@xla//xla/service/gpu:gpu_transfer_manager",
1053+
"@xla//xla/service/gpu:nvptx_compiler",
1054+
"@xla//xla/service/gpu/model:hlo_op_profile_proto_cc_impl",
1055+
"@xla//xla/service/gpu/model:hlo_op_profiles",
1056+
"@xla//xla/stream_executor:cuda_platform",
1057+
"@xla//xla/stream_executor:kernel",
1058+
"@xla//xla/stream_executor/cuda:all_runtime",
1059+
]) + if_rocm([
10491060
"@xla//xla/stream_executor:rocm_platform",
10501061
"@xla//xla/service/gpu:amdgpu_compiler",
10511062
"@xla//xla/backends/profiler/gpu:device_tracer",

deps/ReactantExtra/WORKSPACE

Lines changed: 43 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ http_archive(
1111
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1212
)
1313

14-
ENZYMEXLA_COMMIT = "6774f1afb90c377bbf234a7a7dbfab4f7b726481"
14+
ENZYMEXLA_COMMIT = "0dc3ef87806ab3c9a695fe5c5689f8e2baf0d6cb"
1515

1616
ENZYMEXLA_SHA256 = ""
1717

@@ -68,7 +68,10 @@ CUPTI_NEW = []
6868

6969
XLA_PATCHES = XLA_PATCHES + CUPTI_NEW + [
7070
"""
71-
sed -i.bak0 "s/kSupportedOpcodes({/kSupportedOpcodes(absl::flat_hash_set<HloOpcode>{/g" xla/service/gpu/gpu_memory_space_assignment.h
71+
sed -i.bak0 "s/kDeprecatedFlags({/kDeprecatedFlags(absl::flat_hash_set<std::string>{/g" xla/debug_options_flags.cc
72+
""",
73+
"""
74+
sed -i.bak0 "s/kStableFlags({/kStableFlags(absl::flat_hash_set<std::string>{/g" xla/debug_options_flags.cc
7275
""",
7376
"""
7477
sed -i.bak0 "s/cupti_driver_cbid/cupti/g" xla/backends/profiler/gpu/cupti_tracer.cc
@@ -102,17 +105,10 @@ sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.
102105
# """,
103106
]
104107

105-
LLVM_TARGETS = select({
106-
"@bazel_tools//src/conditions:windows": [
107-
"AMDGPU",
108-
"NVPTX",
109-
],
110-
"@bazel_tools//src/conditions:darwin": [],
111-
"//conditions:default": [
112-
"AMDGPU",
113-
"NVPTX",
114-
],
115-
}) + [
108+
LLVM_TARGETS = [
109+
"AMDGPU",
110+
"NVPTX",
111+
] + [
116112
"AArch64",
117113
"X86",
118114
"ARM",
@@ -237,6 +233,17 @@ load("@jax//third_party/xla:workspace.bzl", jax_xla_workspace = "repo")
237233

238234
jax_xla_workspace()
239235

236+
load("@xla//third_party/llvm:workspace.bzl", llvm = "repo")
237+
238+
llvm("llvm-raw")
239+
240+
load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure")
241+
242+
llvm_configure(
243+
name = "llvm-project",
244+
targets = LLVM_TARGETS,
245+
)
246+
240247
load("@xla//:workspace4.bzl", "xla_workspace4")
241248

242249
xla_workspace4()
@@ -245,14 +252,8 @@ load("@xla//:workspace3.bzl", "xla_workspace3")
245252

246253
xla_workspace3()
247254

248-
load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure")
249255
load("@xla//:workspace2.bzl", "xla_workspace2")
250256

251-
llvm_configure(
252-
name = "llvm-project",
253-
targets = LLVM_TARGETS,
254-
)
255-
256257
xla_workspace2()
257258

258259
load("@xla//:workspace1.bzl", "xla_workspace1")
@@ -285,7 +286,18 @@ load("@jax//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
285286
flatbuffers()
286287

287288
load(
288-
"@xla//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
289+
"@rules_ml_toolchain//cc_toolchain/deps:cc_toolchain_deps.bzl",
290+
"cc_toolchain_deps",
291+
)
292+
293+
cc_toolchain_deps()
294+
295+
register_toolchains("@rules_ml_toolchain//cc_toolchain:lx64_lx64")
296+
297+
register_toolchains("@rules_ml_toolchain//cc_toolchain:lx64_lx64_cuda")
298+
299+
load(
300+
"@rules_ml_toolchain//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
289301
"cuda_json_init_repository",
290302
)
291303

@@ -297,7 +309,12 @@ load(
297309
"CUDNN_REDISTRIBUTIONS",
298310
)
299311
load(
300-
"@xla//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
312+
"@cuda_redist_json//:distributions.bzl",
313+
"CUDA_REDISTRIBUTIONS",
314+
"CUDNN_REDISTRIBUTIONS",
315+
)
316+
load(
317+
"@rules_ml_toolchain//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
301318
"cuda_redist_init_repositories",
302319
"cudnn_redist_init_repository",
303320
)
@@ -311,28 +328,28 @@ cudnn_redist_init_repository(
311328
)
312329

313330
load(
314-
"@xla//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
331+
"@rules_ml_toolchain//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
315332
"cuda_configure",
316333
)
317334

318335
cuda_configure(name = "local_config_cuda")
319336

320337
load(
321-
"@xla//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
338+
"@rules_ml_toolchain//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
322339
"nccl_redist_init_repository",
323340
)
324341

325342
nccl_redist_init_repository()
326343

327344
load(
328-
"@xla//third_party/nccl/hermetic:nccl_configure.bzl",
345+
"@rules_ml_toolchain//third_party/nccl/hermetic:nccl_configure.bzl",
329346
"nccl_configure",
330347
)
331348

332349
nccl_configure(name = "local_config_nccl")
333350

334351
load(
335-
"@xla//third_party/nvshmem/hermetic:nvshmem_json_init_repository.bzl",
352+
"@rules_ml_toolchain//third_party/nvshmem/hermetic:nvshmem_json_init_repository.bzl",
336353
"nvshmem_json_init_repository",
337354
)
338355

@@ -343,17 +360,10 @@ load(
343360
"NVSHMEM_REDISTRIBUTIONS",
344361
)
345362
load(
346-
"@xla//third_party/nvshmem/hermetic:nvshmem_redist_init_repository.bzl",
363+
"@rules_ml_toolchain//third_party/nvshmem/hermetic:nvshmem_redist_init_repository.bzl",
347364
"nvshmem_redist_init_repository",
348365
)
349366

350367
nvshmem_redist_init_repository(
351368
nvshmem_redistributions = NVSHMEM_REDISTRIBUTIONS,
352369
)
353-
354-
load(
355-
"@xla//third_party/nvshmem/hermetic:nvshmem_configure.bzl",
356-
"nvshmem_configure",
357-
)
358-
359-
nvshmem_configure(name = "local_config_nvshmem")

0 commit comments

Comments
 (0)