Skip to content

Commit 6cd102f

Browse files
authored
feat: export reactant compiled functions as tf SavedModel (#1426)
* fix: don't initialize jax arrays for tracing * fix: disable tf gpus * chore: setup code for serialization * feat: mostly working * feat: working saved model export * chore: run fmt * docs: serialization to savedmodel docs * chore: fmt * test: serialization * Update Project.toml
1 parent 608cf11 commit 6cd102f

File tree

15 files changed

+529
-53
lines changed

15 files changed

+529
-53
lines changed

CondaPkg.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
[pip.deps]
2-
jax = ">=0.4"
2+
jax = ">= 0.6"
3+
tensorflow = ">= 2.17"
4+
numpy = ">= 2"

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ OneHotArrays = "0.2.10"
8686
OrderedCollections = "1"
8787
PrecompileTools = "1.2"
8888
Preferences = "1.4"
89-
PythonCall = "0.9"
89+
PythonCall = "0.9.25"
9090
Random = "1.10"
9191
Random123 = "1.7"
9292
ReactantCore = "0.1.15"

docs/src/.vitepress/config.mts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ export default defineConfig({
9595
items: [
9696
{ text: "Core Reactant API", link: "/api/api" },
9797
{ text: "Sharding", link: "/api/sharding" },
98+
{ text: "Serialization", link: "/api/serialization" },
9899
{ text: "Ops", link: "/api/ops" },
99100
{ text: "Configuration", link: "/api/config" },
100101
{
@@ -169,6 +170,7 @@ export default defineConfig({
169170
link: "/api/api",
170171
},
171172
{ text: "Sharding", link: "/api/sharding" },
173+
{ text: "Serialization", link: "/api/serialization" },
172174
{ text: "Ops", link: "/api/ops" },
173175
{ text: "Configuration", link: "/api/config" },
174176
{

docs/src/api/serialization.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
```@meta
2+
CollapsedDocStrings = true
3+
```
4+
5+
# Serialization
6+
7+
```@docs
8+
Reactant.Serialization
9+
```
10+
11+
## Exporting to TensorFlow SavedModel
12+
13+
!!! note "Load PythonCall"
14+
15+
Serialization to TensorFlow SavedModel requires PythonCall to be loaded. Loading
16+
PythonCall will automatically install tensorflow. If tensorflow installation fails,
17+
we won't be able to export to SavedModel.
18+
19+
A SavedModel contains a complete TensorFlow program, including trained parameters (i.e.,
20+
tf.Variables) and computation. It does not require the original model building code to run,
21+
which makes it useful for sharing or deploying with [TFLite](https://tensorflow.org/lite),
22+
[TensorFlow.js](https://js.tensorflow.org/),
23+
[TensorFlow Serving](https://www.tensorflow.org/tfx/serving/tutorials/Serving_REST_simple),
24+
or [TensorFlow Hub](https://tensorflow.org/hub). Refer to the
25+
[official documentation](https://www.tensorflow.org/guide/saved_model) for more details.
26+
27+
```@docs
28+
Reactant.Serialization.export_as_tf_saved_model
29+
```

ext/ReactantPythonCallExt.jl

Lines changed: 0 additions & 49 deletions
This file was deleted.
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
module ReactantPythonCallExt
2+
3+
using PythonCall
4+
using Reactant: Reactant, TracedRArray
5+
6+
const jaxptr = Ref{Py}()
7+
const jnpptr = Ref{Py}()
8+
9+
const JAX_TRACING_SUPPORTED = Ref{Bool}(false)
10+
11+
const tfptr = Ref{Py}()
12+
const tf2xlaptr = Ref{Py}()
13+
const npptr = Ref{Py}()
14+
15+
const SAVED_MODEL_EXPORT_SUPPORTED = Ref{Bool}(false)
16+
17+
# Maps Julia element types to the name of the matching NumPy dtype.
#
# NumPy names complex dtypes by their *total* bit width, so a complex number made of two
# `Float32`s is `complex64` and one made of two `Float64`s is `complex128`. NumPy has no
# dtype for `ComplexF16`, so it is intentionally absent and a lookup raises a `KeyError`
# instead of handing an invalid dtype name to Python.
const NUMPY_SIMPLE_TYPES = Dict{DataType,Symbol}(
    Bool => :bool,
    Int8 => :int8,
    Int16 => :int16,
    Int32 => :int32,
    Int64 => :int64,
    UInt8 => :uint8,
    UInt16 => :uint16,
    UInt32 => :uint32,
    UInt64 => :uint64,
    Float16 => :float16,
    Float32 => :float32,
    Float64 => :float64,
    ComplexF32 => :complex64,
    ComplexF64 => :complex128,
)
34+
35+
# Try to import jax and jax.numpy. On success the module handles are stored and jax
# tracing is marked as supported; on failure a warning is logged and the flag keeps its
# previous value.
function _try_load_jax!()
    try
        jaxptr[] = pyimport("jax")
        jnpptr[] = pyimport("jax.numpy")
        JAX_TRACING_SUPPORTED[] = true
    catch e
        @warn "Failed to import jax. Tracing jax functions invoked with pycall won't \
               be supported." exception = (e, catch_backtrace())
    end
    return nothing
end

# Try to import tensorflow (with all GPU devices hidden from it), the tf2xla python
# bindings and numpy. On success SavedModel export is marked as supported; on failure a
# warning is logged and the flag keeps its previous value.
function _try_load_tensorflow!()
    try
        tfptr[] = pyimport("tensorflow")
        # Passing an empty device list makes no GPU visible to tensorflow.
        tfptr[].config.set_visible_devices(pylist(); device_type="GPU")
        tf2xlaptr[] = pyimport("tensorflow.compiler.tf2xla.python.xla")
        npptr[] = pyimport("numpy")
        SAVED_MODEL_EXPORT_SUPPORTED[] = true
    catch e
        @warn "Failed to import tensorflow. Exporting Reactant compiled functions as \
               tensorflow SavedModel will not be supported." exception = (
            e, catch_backtrace()
        )
    end
    return nothing
end

# Module initializer: attempt each optional Python integration independently, so a
# failure to load one of them does not prevent the other from being set up.
function __init__()
    _try_load_jax!()
    _try_load_tensorflow!()
    return nothing
end
58+
59+
include("pycall.jl")
60+
include("saved_model.jl")
61+
62+
end

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Call the Python function `f` on traced Reactant arrays by lowering it through jax.
#
# Each argument is described to jax by a `ShapeDtypeStruct` built from its size and
# element type, `f` is lowered with `jax.jit(f).lower(...)`, and the lowered module text is
# spliced into the current trace with `Reactant.Ops.hlo_call`.
#
# Returns the first result of the lowered call, or `nothing` when it has no results.
# Throws an `ErrorException` if jax could not be imported when the extension was loaded,
# and a `KeyError` if an argument's element type has no entry in `NUMPY_SIMPLE_TYPES`.
#
# NOTE(review): `kwargs` is accepted but never used in this method body.
function PythonCall.pycall(f::Py, arg0::TracedRArray, argNs::TracedRArray...; kwargs...)
    # `error` raises a proper `ErrorException` instead of throwing a bare `String`.
    JAX_TRACING_SUPPORTED[] || error("jax could not be loaded.")

    jax = jaxptr[]
    jnp = jnpptr[]

    inputs = map((arg0, argNs...)) do arg
        jax.ShapeDtypeStruct(
            size(arg),
            jnp.dtype(string(NUMPY_SIMPLE_TYPES[Reactant.unwrapped_eltype(arg)])),
        )
    end

    lowered = jax.jit(f).lower(inputs...)
    res = Reactant.Ops.hlo_call(pyconvert(String, lowered.as_text()), arg0, argNs...)

    return isempty(res) ? nothing : res[1]
end
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# TODO: at some point, we should use the TF C++ API to export the SavedModel
2+
3+
# SavedModel export is available exactly when the `SAVED_MODEL_EXPORT_SUPPORTED` flag
# is set.
Reactant.Serialization.serialization_supported(::Val{:SavedModel}) =
    SAVED_MODEL_EXPORT_SUPPORTED[]
6+
7+
# Assemble, in the order given by `input_locations`, the Python list of operands for the
# exported module.
#
# A location that is an `InputArgument` selects `args[loc.position]`; any other location
# is looked up by `loc.name` in `state_dict`.
#
# Throws an `ArgumentError` when `state_dict` has no entry for a requested name.
function _extract_call_parameters(args::Tuple, input_locations, state_dict)
    call_args = pylist()
    for loc in input_locations
        if loc isa Reactant.Serialization.TFSavedModel.InputArgument
            call_args.append(args[loc.position])
        else
            # Checked explicitly: `@assert` is meant for internal invariants and may be
            # disabled, which would turn a missing key into a less helpful `KeyError`.
            haskey(state_dict, loc.name) || throw(
                ArgumentError("State dictionary does not contain key: $(loc.name)")
            )
            call_args.append(state_dict[loc.name])
        end
    end
    return call_args
end
20+
21+
# Wrap the serialized module of `spec` as a Python callable.
#
# When called, the returned callable gathers its operands with `_extract_call_parameters`
# (from its positional arguments and `state_dict`) and passes them, together with the
# module bytecode and the output dtypes/shapes, to the tf2xla `call_module` function.
function _wrap_as_tf_func(
    spec::Reactant.Serialization.TFSavedModel.ReactantFunctionSpec, state_dict
)
    # Output metadata is converted to Python lists once and reused on every call.
    out_dtypes = pylist(string(out.dtype) for out in spec.output_signature)
    out_shapes = pylist(pylist(out.shape) for out in spec.output_signature)

    run_module =
        (args...) -> tf2xlaptr[].call_module(
            _extract_call_parameters(args, spec.input_locations, state_dict);
            version=5,
            Tout=out_dtypes, # dtype information
            Sout=out_shapes, # Shape information
            function_list=pylist([]), # No functions to call
            # `module` is a reserved word in Julia, so the keyword is passed as a pair.
            :module => spec.bytecode,
        )

    return pyfunc(run_module)
end
39+
40+
# Build the Python list of `tf.TensorSpec`s describing the call arguments of `fn_spec`.
#
# Only locations that are `InputArgument`s contribute; the specs are emitted in argument
# position order and named `args_0`, `args_1`, ... The lookup by `1:n` assumes that the
# argument positions are exactly `1:n`.
function _make_input_signatures(
    fn_spec::Reactant.Serialization.TFSavedModel.ReactantFunctionSpec
)
    # Signature entry of every true call argument, keyed by its argument position.
    spec_by_position = Dict(
        loc.position => sig for
        (loc, sig) in zip(fn_spec.input_locations, fn_spec.input_signature) if
        loc isa Reactant.Serialization.TFSavedModel.InputArgument
    )

    tensor_specs = map(1:length(spec_by_position)) do pos
        sig = spec_by_position[pos]
        tf_dtype = getproperty(tfptr[], sig.dtype)
        # Names are zero-based on the Python side.
        return tfptr[].TensorSpec(;
            shape=pylist(sig.shape), dtype=tf_dtype, name="args_$(pos - 1)"
        )
    end

    return pylist(tensor_specs)
end
62+
63+
# Export `fn_spec` as a TensorFlow SavedModel written to `path`.
#
# Every entry of `fn_spec.state_dict` becomes a non-trainable `tf.Variable`, the wrapped
# function is attached to a fresh `tf.Module` as a `tf.function` with the input signature
# from `_make_input_signatures`, and the module is saved with that function registered
# under the default serving signature key. Returns `nothing`.
function Reactant.Serialization.TFSavedModel.__to_tf_saved_model(
    fn_spec::Reactant.Serialization.TFSavedModel.ReactantFunctionSpec, path::String
)
    tf = tfptr[]
    np = npptr[]

    tf_module = tf.Module()

    # Each array is handed to numpy with its dimension order reversed.
    # NOTE(review): presumably this bridges Julia's column-major and numpy's row-major
    # layouts; confirm.
    variables = Dict(
        key => tf.Variable(
            np.asarray(permutedims(value, collect(ndims(value):-1:1)));
            trainable=false,
            name=key,
        ) for (key, value) in fn_spec.state_dict
    )

    arg_specs = _make_input_signatures(fn_spec)

    # `function` is a reserved word in Julia, hence the explicit `getproperty`.
    tf_function = getproperty(tf, :function)
    tf_module.f = tf_function(
        _wrap_as_tf_func(fn_spec, variables); input_signature=arg_specs
    )
    tf_module._variables = pylist(collect(values(variables)))

    serving_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    signatures = pydict([serving_key => tf_module.f.get_concrete_function(arg_specs...)])
    save_options = tf.saved_model.SaveOptions(;
        function_aliases=pydict(["" => tf_module.f])
    )

    tf.saved_model.save(tf_module, path; signatures=signatures, options=save_options)

    return nothing
end

src/Compiler.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2449,11 +2449,16 @@ Compile the function `f` with arguments `args` and return the compiled function.
24492449
24502450
$(SYNC_DOCS)
24512451
$(COMMON_COMPILE_OPTIONS_DOCS)
2452+
- `serializable`: If `true`, the compiled function will be serializable. This is needed
2453+
for saving the compiled function to disk and loading it later. Defaults to `false`.
24522454
24532455
See also [`@jit`](@ref), [`@code_hlo`](@ref), [`@code_mhlo`](@ref), [`@code_xla`](@ref).
24542456
"""
24552457
macro compile(args...)
2456-
default_options = merge(get_common_compile_options(), Dict{Symbol,Any}(:sync => false))
2458+
default_options = merge(
2459+
get_common_compile_options(),
2460+
Dict{Symbol,Any}(:sync => false, :serializable => false),
2461+
)
24572462
return esc(first(compile_call_expr(__module__, compile, default_options, args...)))
24582463
end
24592464

@@ -3413,7 +3418,7 @@ function compile_xla(f, args; client=nothing, serializable::Bool=false, kwargs..
34133418
# XLA.compile mutates the module, for serialization we need to keep a copy
34143419
if serializable
34153420
iobuffer = IOBuffer()
3416-
show(IOContext(iobuffer, :debug => debug), mod)
3421+
show(IOContext(iobuffer, :debug => true), mod)
34173422
module_string = String(take!(iobuffer))
34183423
else
34193424
module_string = ""

src/Reactant.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,9 @@ include("Compiler.jl")
190190

191191
include("Overlay.jl")
192192

193+
# Serialization
194+
include("serialization/Serialization.jl")
195+
193196
using .Compiler: @compile, @code_hlo, @code_mhlo, @jit, @code_xla, traced_getfield, compile
194197
export ConcreteRArray,
195198
ConcreteRNumber,

0 commit comments

Comments
 (0)