Skip to content

Commit 4fe55a6

Browse files
committed
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into probprog-trace-operand
2 parents e647b0d + 166f670 commit 4fe55a6

File tree

16 files changed

+1075
-103
lines changed

16 files changed

+1075
-103
lines changed

.github/workflows/CI-localjll.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ jobs:
3939
env:
4040
TMPDIR: ${{ github.workspace }}/tmp
4141
steps:
42+
- name: Free Disk Space
43+
uses: jlumbroso/free-disk-space@main
44+
with:
45+
tool-cache: false
46+
if: ${{ startsWith(matrix.os, 'ubuntu-') }}
47+
- name: Clean `/opt`
48+
run: sudo rm -rf /opt/*
4249
- uses: actions/checkout@v4
4350
- name: Create TMPDIR
4451
run: |

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.145"
4+
version = "0.2.146"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -90,7 +90,7 @@ PythonCall = "0.9.25"
9090
Random = "1.10"
9191
Random123 = "1.7"
9292
ReactantCore = "0.1.15"
93-
Reactant_jll = "0.0.217"
93+
Reactant_jll = "0.0.219"
9494
ScopedValues = "1.3.0"
9595
Scratch = "1.2"
9696
Sockets = "1.10"

deps/ReactantExtra/API.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
8282

8383
#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h"
84+
#include "xla/hlo/translate/stablehlo.h"
8485

8586
// CPU collectives
8687
#include "xla/backends/cpu/collectives/mpi_collectives.h"
@@ -2956,3 +2957,10 @@ extern "C" void reactantXLAExec(LinkableRuntime **__restrict__ lrtP,
29562957
}
29572958
}
29582959
}
2960+
2961+
extern "C" HeldHloModule *convertMlirModuleToHloModule(MlirModule mod) {
2962+
mlir::ModuleOp cmod_op = cast<ModuleOp>(*unwrap(mod));
2963+
std::shared_ptr<xla::HloModule> hlo_module =
2964+
std::move(MyValueOrThrow(xla::ConvertStablehloToHlo(cmod_op)));
2965+
return reactant::capture(hlo_module);
2966+
}

deps/ReactantExtra/BUILD

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -890,6 +890,7 @@ cc_library(
890890
"-Wl,-exported_symbol,_addSdyPropagationPipeline",
891891
"-Wl,-exported_symbol,_mlirGetFunctionTypeFromOperation",
892892
"-Wl,-exported_symbol,_mlirIsFunctionOpInterface",
893+
"-Wl,-exported_symbol,_convertMlirModuleToHloModule",
893894
],
894895
}),
895896
linkstatic = True,
@@ -960,6 +961,7 @@ cc_library(
960961
"@xla//xla/pjrt:status_casters",
961962
"@xla//xla/python/ifrt",
962963
"@xla//xla/python/pjrt_ifrt",
964+
"@xla//xla/hlo/translate:stablehlo",
963965
"@xla//xla/python/ifrt_proxy/server:grpc_server",
964966
"@xla//xla/python/ifrt_proxy/client:grpc_client",
965967
"@xla//xla/python/ifrt_proxy/client:registry",
@@ -1310,6 +1312,24 @@ gentbl_cc_library(
13101312
],
13111313
)
13121314

1315+
gentbl_cc_library(
1316+
name = "MosaicGPUJLIncGen",
1317+
tbl_outs = [
1318+
(
1319+
[
1320+
"--generator=jl-op-defs",
1321+
"--disable-module-wrap=0",
1322+
],
1323+
"MosaicGPU.jl",
1324+
),
1325+
],
1326+
tblgen = "//:mlir-jl-tblgen",
1327+
td_file = "@jax//jaxlib/mosaic/dialect/gpu:mosaic_gpu.td",
1328+
deps = [
1329+
"@jax//jaxlib/mosaic/dialect/gpu:mosaic_gpu_td_files",
1330+
],
1331+
)
1332+
13131333
gentbl_cc_library(
13141334
name = "TritonJLIncGen",
13151335
tbl_outs = [
@@ -1440,14 +1460,19 @@ genrule(
14401460
"@stablehlo//:stablehlo/integrations/c/StablehloDialectApi.h",
14411461
"@stablehlo//:stablehlo/integrations/c/StablehloTypes.h",
14421462
"@shardy//shardy/integrations/c:attributes.h",
1463+
"@jax//jaxlib/triton:triton_dialect_capi.h",
1464+
"@jax//jaxlib/mosaic:dialect/tpu/integrations/c/tpu_dialect.h",
1465+
"@jax//jaxlib/mosaic:dialect/tpu/integrations/c/tpu_passes.capi.h.inc",
1466+
"@jax//jaxlib/mosaic/dialect/gpu:integrations/c/attributes.h",
1467+
"@jax//jaxlib/mosaic/dialect/gpu:integrations/c/gpu_dialect.h",
14431468
"//:Project.toml",
14441469
"//:Manifest.toml",
14451470
"//:wrap.toml",
14461471
"//:missing_defs.jl",
14471472
"//:make.jl",
14481473
],
14491474
outs = ["libMLIR_h.jl"],
1450-
cmd = "$$JULIA \"--color=yes\" \"--project=$(location //:Project.toml)\" \"$(location //:make.jl)\" \"$(location @llvm-project//mlir:include/mlir-c/Bindings/Python/Interop.h)\" \"$(location @llvm-project//llvm:include/llvm-c/Support.h)\" \"$(locations @llvm-project//mlir:ConversionPassIncGen_filegroup)\" \"$(location @stablehlo//:stablehlo/integrations/c/StablehloAttributes.h)\" \"$(location @shardy//shardy/integrations/c:attributes.h)\" \"$@\"",
1475+
cmd = "$$JULIA \"--color=yes\" \"--project=$(location //:Project.toml)\" \"$(location //:make.jl)\" \"$(location @llvm-project//mlir:include/mlir-c/Bindings/Python/Interop.h)\" \"$(location @llvm-project//llvm:include/llvm-c/Support.h)\" \"$(locations @llvm-project//mlir:ConversionPassIncGen_filegroup)\" \"$(location @stablehlo//:stablehlo/integrations/c/StablehloAttributes.h)\" \"$(location @shardy//shardy/integrations/c:attributes.h)\" \"$(location @jax//jaxlib/triton:triton_dialect_capi.h)\" \"$(location @jax//jaxlib/mosaic:dialect/tpu/integrations/c/tpu_dialect.h)\" \"$(location @jax//jaxlib/mosaic/dialect/gpu:integrations/c/gpu_dialect.h)\" \"$@\"",
14511476
tags = [
14521477
"jlrule",
14531478
],

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ http_archive(
1111
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1212
)
1313

14-
ENZYMEXLA_COMMIT = "20aff0cfd430339ca01c9febb96675d62a4a7995"
14+
ENZYMEXLA_COMMIT = "6774f1afb90c377bbf234a7a7dbfab4f7b726481"
1515

1616
ENZYMEXLA_SHA256 = ""
1717

deps/ReactantExtra/make-bindings.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ for file in [
3636
"Gpu.jl",
3737
"Affine.jl",
3838
"TPU.jl",
39+
"MosaicGPU.jl",
3940
"Triton.jl",
4041
"Shardy.jl",
4142
"MPI.jl",

deps/ReactantExtra/make.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@ let options = deepcopy(options)
1919
genarg = first(eachsplit(ARGS[3], " "))
2020

2121
gen_include_dir = joinpath(splitpath(genarg)[1:(end - 4)]...)
22-
23-
hlo_include_dir = joinpath(splitpath(ARGS[end - 2])[1:(end - 1)]...)
24-
25-
sdy_include_dir = joinpath(splitpath(ARGS[end - 1])[1:(end - 1)]...)
22+
hlo_include_dir = joinpath(splitpath(ARGS[end - 5])[1:(end - 1)]...)
23+
sdy_include_dir = joinpath(splitpath(ARGS[end - 4])[1:(end - 1)]...)
24+
triton_include_dir = joinpath(splitpath(ARGS[end - 3])[1:(end - 1)]...)
25+
mosaic_tpu_include_dir = joinpath(splitpath(ARGS[end - 2])[1:(end - 1)]...)
26+
mosaic_gpu_include_dir = joinpath(splitpath(ARGS[end - 1])[1:(end - 1)]...)
2627

2728
append!(
2829
args,
@@ -37,6 +38,12 @@ let options = deepcopy(options)
3738
hlo_include_dir,
3839
"-I",
3940
sdy_include_dir,
41+
"-I",
42+
triton_include_dir,
43+
"-I",
44+
mosaic_tpu_include_dir,
45+
"-I",
46+
mosaic_gpu_include_dir,
4047
"-x",
4148
"c++",
4249
],
@@ -46,6 +53,9 @@ let options = deepcopy(options)
4653
detect_headers(include_dir, args, Dict(), endswith("Python/Interop.h"))...,
4754
detect_headers(hlo_include_dir, args, Dict())...,
4855
detect_headers(sdy_include_dir, args, Dict())...,
56+
detect_headers(triton_include_dir, args, Dict())...,
57+
detect_headers(mosaic_tpu_include_dir, args, Dict())...,
58+
detect_headers(mosaic_gpu_include_dir, args, Dict())...,
4959
]
5060

5161
ctx = create_context(headers, args, options)

docs/src/.vitepress/config.mts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ export default defineConfig({
112112
{ text: "LLVM", link: "/api/dialects/llvm" },
113113
{ text: "MPI", link: "/api/dialects/mpi" },
114114
{ text: "MemRef", link: "/api/dialects/memref" },
115+
{ text: "Mosaic GPU", link: "/api/dialects/mosaicgpu" },
115116
{ text: "NVVM", link: "/api/dialects/nvvm" },
116117
{ text: "Shardy", link: "/api/dialects/shardy" },
117118
{ text: "SparseTensor", link: "/api/dialects/sparsetensor" },
@@ -188,6 +189,7 @@ export default defineConfig({
188189
{ text: "LLVM", link: "/api/dialects/llvm" },
189190
{ text: "MPI", link: "/api/dialects/mpi" },
190191
{ text: "MemRef", link: "/api/dialects/memref" },
192+
{ text: "Mosaic GPU", link: "/api/dialects/mosaicgpu" },
191193
{ text: "NVVM", link: "/api/dialects/nvvm" },
192194
{ text: "Shardy", link: "/api/dialects/shardy" },
193195
{ text: "SparseTensor", link: "/api/dialects/sparsetensor" },

docs/src/api/dialects/mosaicgpu.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
```@meta
2+
CollapsedDocStrings = true
3+
```
4+
5+
# Mosaic GPU Dialect
6+
7+
```@autodocs
8+
Modules = [Reactant.MLIR.Dialects.mosaic_gpu]
9+
```

0 commit comments

Comments
 (0)