Skip to content

Commit 7f611fe

Browse files
committed
Merge branch 'main' of https://github.com/EnzymeAD/Reactant.jl into probprog
2 parents 0f94166 + a874ef7 commit 7f611fe

29 files changed

+969
-126
lines changed

.github/workflows/format-pr-bazel.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ jobs:
2424
uses: peter-evans/create-pull-request@v7
2525
with:
2626
token: ${{ secrets.GITHUB_TOKEN }}
27-
commit-message: Format code
28-
title: 'Format code of branch "main"'
29-
branch: format-main
27+
commit-message: Format Bazel code
28+
title: 'Format Bazel code of branch "main"'
29+
branch: format-main-bazel
3030
delete-branch: true
3131
labels: format
3232
author: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>

.github/workflows/format-pr.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Format 'main'
1+
name: Format Julia code on 'main'
22

33
on:
44
schedule:
@@ -34,9 +34,9 @@ jobs:
3434
uses: peter-evans/create-pull-request@v7
3535
with:
3636
token: ${{ secrets.GITHUB_TOKEN }}
37-
commit-message: Format code
38-
title: 'Format code of branch "main"'
39-
branch: format-main
37+
commit-message: Format Julia code
38+
title: 'Format Julia code of branch "main"'
39+
branch: format-main-julia
4040
delete-branch: true
4141
labels: format
4242
author: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>

.github/workflows/regenerate-mlir-bindings.yml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,17 @@ jobs:
1616
with:
1717
version: '1.10'
1818
- uses: julia-actions/cache@v2
19-
- uses: actions/checkout@v4
19+
- uses: bazel-contrib/[email protected]
20+
name: Set up Bazel
2021
with:
21-
ref: main
22+
# Avoid downloading Bazel every time.
23+
bazelisk-cache: true
24+
# Store build cache per workflow.
25+
disk-cache: ${{ github.workflow }}-${{ matrix.os }}-${{ matrix.version }}
26+
# Share repository cache between workflows.
27+
repository-cache: true
28+
bazelisk-version: 1.x
29+
- uses: actions/checkout@v4
2230
- name: Install JuliaFormatter.jl
2331
shell: julia --color=yes {0}
2432
run: |

CondaPkg.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
[pip.deps]
2-
jax = ">=0.4"
2+
jax = ">= 0.6"
3+
tensorflow = ">= 2.17"
4+
numpy = ">= 2"

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.139"
4+
version = "0.2.143"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -86,11 +86,11 @@ OneHotArrays = "0.2.10"
8686
OrderedCollections = "1"
8787
PrecompileTools = "1.2"
8888
Preferences = "1.4"
89-
PythonCall = "0.9"
89+
PythonCall = "0.9.25"
9090
Random = "1.10"
9191
Random123 = "1.7"
9292
ReactantCore = "0.1.15"
93-
Reactant_jll = "0.0.213"
93+
Reactant_jll = "0.0.214"
9494
ScopedValues = "1.3.0"
9595
Scratch = "1.2"
9696
Sockets = "1.10"

deps/ReactantExtra/API.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,17 @@ extern "C" MlirOperation mlirOperationParse(MlirContext ctx, MlirBlock block,
270270
.release()};
271271
}
272272

273+
extern "C" MlirType mlirGetFunctionTypeFromOperation(MlirOperation op) {
274+
if (auto funcOp = dyn_cast<mlir::FunctionOpInterface>(unwrap(op))) {
275+
return wrap(funcOp.getFunctionType());
276+
}
277+
ReactantThrowError("Not a function op");
278+
}
279+
280+
extern "C" bool mlirIsFunctionOpInterface(MlirOperation op) {
281+
return llvm::isa<mlir::FunctionOpInterface>(unwrap(op));
282+
}
283+
273284
// TODO mlirComplexAttrGetnValue
274285
// TODO extern "C" MlirTypeID mlirComplexAttrGetTypeID(void) { return
275286
// wrap(complex::NumberAttr::getTypeID()); }

deps/ReactantExtra/BUILD

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,8 @@ cc_library(
888888
"-Wl,-exported_symbol,_hlo_sharding_*",
889889
"-Wl,-exported_symbol,_free_ifrt_sharding",
890890
"-Wl,-exported_symbol,_addSdyPropagationPipeline",
891+
"-Wl,-exported_symbol,_mlirGetFunctionTypeFromOperation",
892+
"-Wl,-exported_symbol,_mlirIsFunctionOpInterface",
891893
],
892894
}),
893895
linkstatic = True,
@@ -1434,6 +1436,9 @@ genrule(
14341436
"@llvm-project//mlir:AsyncPassIncGen_filegroup",
14351437
"@llvm-project//mlir:GPUPassIncGen_filegroup",
14361438
"@stablehlo//:stablehlo/integrations/c/StablehloAttributes.h",
1439+
"@stablehlo//:stablehlo/integrations/c/StablehloDialect.h",
1440+
"@stablehlo//:stablehlo/integrations/c/StablehloDialectApi.h",
1441+
"@stablehlo//:stablehlo/integrations/c/StablehloTypes.h",
14371442
"@shardy//shardy/integrations/c:attributes.h",
14381443
"//:Project.toml",
14391444
"//:Manifest.toml",
@@ -1442,7 +1447,7 @@ genrule(
14421447
"//:make.jl",
14431448
],
14441449
outs = ["libMLIR_h.jl"],
1445-
cmd = "$$JULIA \"--project=$(location //:Project.toml)\" \"$(location //:make.jl)\" \"$(location @llvm-project//mlir:include/mlir-c/Bindings/Python/Interop.h)\" \"$(location @llvm-project//llvm:include/llvm-c/Support.h)\" \"$(locations @llvm-project//mlir:ConversionPassIncGen_filegroup)\" \"$(location @stablehlo//:stablehlo/integrations/c/StablehloAttributes.h)\" \"$(location @shardy//shardy/integrations/c:attributes.h)\" \"$@\"",
1450+
cmd = "$$JULIA \"--color=yes\" \"--project=$(location //:Project.toml)\" \"$(location //:make.jl)\" \"$(location @llvm-project//mlir:include/mlir-c/Bindings/Python/Interop.h)\" \"$(location @llvm-project//llvm:include/llvm-c/Support.h)\" \"$(locations @llvm-project//mlir:ConversionPassIncGen_filegroup)\" \"$(location @stablehlo//:stablehlo/integrations/c/StablehloAttributes.h)\" \"$(location @shardy//shardy/integrations/c:attributes.h)\" \"$@\"",
14461451
tags = [
14471452
"jlrule",
14481453
],

deps/ReactantExtra/WORKSPACE

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ http_archive(
1111
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1212
)
1313

14-
ENZYMEXLA_COMMIT = "e52dc6b096213e21940693cda860c99ada2de247"
14+
ENZYMEXLA_COMMIT = "2527ca4bb8fa4499cd10ffb42ce4c2cda3738e91"
1515

1616
ENZYMEXLA_SHA256 = ""
1717

@@ -52,7 +52,24 @@ hedron_compile_commands_setup_transitive_transitive_transitive()
5252

5353
load("@enzyme_ad//:workspace.bzl", "ENZYME_COMMIT", "ENZYME_SHA256", "JAX_COMMIT", "JAX_SHA256", "XLA_PATCHES")
5454

55-
XLA_PATCHES = XLA_PATCHES + [
55+
CUPTI_OLD = [
56+
"""
57+
sed -i.bak0 "s/cupti_driver_cbid/cupti/g" xla/backends/profiler/gpu/cupti_tracer.cc
58+
""",
59+
"""
60+
sed -i.bak0 "/CUPTI_DRIVER_TRACE_CBID_cuGraphAddNode/d" xla/backends/profiler/gpu/cupti_tracer.cc
61+
""",
62+
"""
63+
sed -i.bak0 "/CUPTI_DRIVER_TRACE_CBID_cuGraphAddNode_v2/d" xla/backends/profiler/gpu/cupti_tracer.cc
64+
""",
65+
]
66+
67+
CUPTI_NEW = []
68+
69+
XLA_PATCHES = XLA_PATCHES + CUPTI_NEW + [
70+
"""
71+
sed -i.bak0 "s/cupti_driver_cbid/cupti/g" xla/backends/profiler/gpu/cupti_tracer.cc
72+
""",
5673
"""
5774
sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config.bzl -exec sed -i.bak0 's\\/HAVE_LINK_H=1\\/HAVE_LINK_H=0\\/g' {} +\\\",/g" third_party/llvm/workspace.bzl
5875
""",

docs/src/.vitepress/config.mts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ export default defineConfig({
9595
items: [
9696
{ text: "Core Reactant API", link: "/api/api" },
9797
{ text: "Sharding", link: "/api/sharding" },
98+
{ text: "Serialization", link: "/api/serialization" },
9899
{ text: "Ops", link: "/api/ops" },
99100
{ text: "Configuration", link: "/api/config" },
100101
{
@@ -169,6 +170,7 @@ export default defineConfig({
169170
link: "/api/api",
170171
},
171172
{ text: "Sharding", link: "/api/sharding" },
173+
{ text: "Serialization", link: "/api/serialization" },
172174
{ text: "Ops", link: "/api/ops" },
173175
{ text: "Configuration", link: "/api/config" },
174176
{

docs/src/api/serialization.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
```@meta
2+
CollapsedDocStrings = true
3+
```
4+
5+
# Serialization
6+
7+
```@docs
8+
Reactant.Serialization
9+
```
10+
11+
## Exporting to TensorFlow SavedModel
12+
13+
!!! note "Load PythonCall"
14+
15+
Serialization to TensorFlow SavedModel requires PythonCall to be loaded. Loading
16+
PythonCall will automatically install tensorflow. If tensorflow installation fails,
17+
we won't be able to export to SavedModel.
18+
19+
A SavedModel contains a complete TensorFlow program, including trained parameters (i.e,
20+
tf.Variables) and computation. It does not require the original model building code to run,
21+
which makes it useful for sharing or deploying with [TFLite](https://tensorflow.org/lite),
22+
[TensorFlow.js](https://js.tensorflow.org/),
23+
[TensorFlow Serving](https://www.tensorflow.org/tfx/serving/tutorials/Serving_REST_simple),
24+
or [TensorFlow Hub](https://tensorflow.org/hub). Refer to the
25+
[official documentation](https://www.tensorflow.org/guide/saved_model) for more details.
26+
27+
```@docs
28+
Reactant.Serialization.export_as_tf_saved_model
29+
```

0 commit comments

Comments
 (0)