Skip to content

Commit f2ad3ef

Browse files
authored
@trace while and control flow tutorial (#1374)
* `@trace` while and control flow tutorial * link to tutorial * Update ReactantCore * Trace * update tutorial * tutorial: loops * limitations * Example using reactant
1 parent 24d0f22 commit f2ad3ef

File tree

8 files changed

+347
-81
lines changed

8 files changed

+347
-81
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ Preferences = "1.4"
8989
PythonCall = "0.9"
9090
Random = "1.10"
9191
Random123 = "1.7"
92-
ReactantCore = "0.1.11"
92+
ReactantCore = "0.1.12"
9393
Reactant_jll = "0.0.198"
9494
ScopedValues = "1.3.0"
9595
Scratch = "1.2"

docs/src/.vitepress/config.mts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ export default defineConfig({
8787
{text: "Profiling", link: "/tutorials/profiling"},
8888
{text: "Distributed", link: "/tutorials/multihost"},
8989
{text: "Local build", link: "/tutorials/local-build"},
90+
{text: "Control Flow", link: "/tutorials/control-flow"},
9091
],
9192
},
9293
{

docs/src/tutorials/control-flow.md

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
# [Control Flow](@id control-flow)
2+
3+
Reactant currently uses a tracing system to capture array operations into a new
4+
program. As such, the provided function is executed with [`TracedRArray`](@ref)
5+
as inputs instead of [`ConcreteRArray`](@ref). This means that during tracing
6+
only operations affecting such arrays are captured by Reactant into the new
7+
program.
8+
9+
In practice, this means that Julia native control flow constructs are not
10+
captured by Reactant.
11+
12+
Consider the following function which has a conditional control flow depending
13+
on one of its argument which is a boolean:
14+
15+
```@example control_flow_tutorial
16+
using Reactant
17+
18+
function maybe_square(cond, x)
19+
if cond
20+
x = x .^ 2
21+
else
22+
x = x
23+
end
24+
return x
25+
end
26+
```
27+
28+
We can confirm by compiling our function and noticing that the result does not
29+
depend on the argument provided to the compiled function.
30+
31+
```@example control_flow_tutorial
32+
x = Reactant.ConcreteRArray(randn(Float32, 100))
33+
34+
maybe_square_compiled = @compile maybe_square(true, x)
35+
maybe_square_compiled(false, x) == maybe_square_compiled(true, x)
36+
```
37+
38+
But instead, it depends on the value that was provided during tracing to the
39+
initial `@compile` invocation. This is also confirmed when looking at the
40+
code generated during tracing which does not contain any conditional.
41+
42+
```@example control_flow_tutorial
43+
@code_hlo maybe_square(false, x)
44+
```
45+
46+
The same behaviour can be observed when using loops. In the following example,
47+
the loop is "unrolled" because it is not captured in the program. The optimizer
48+
then fuses all additions to add `n = 10` directly to the argument.
49+
50+
```@example control_flow_tutorial
51+
function add_n(x, n)
52+
for _ in 1:n
53+
x .+= 1
54+
end
55+
return x
56+
end
57+
58+
x = Reactant.to_rarray(zeros(Int, 10))
59+
n = 10
60+
@code_hlo add_n(x, n)
61+
```
62+
63+
In the next section, we will see what mechanism Reactant offers to integrate
64+
data-dependent control flow in the captured programs.
65+
66+
## Data-dependent Control Flow using [`@trace`](@ref)
67+
68+
During tracing the arrays contain no data and only information about their shape
69+
and data type. As such, it is not possible to execute conditions that would
70+
depend on the value of an array. For these cases, ReactantCore provides the
71+
[`@trace`](@ref) macro to allow capturing control flow expressions in the
72+
compiled program.
73+
74+
### Conditional Control Flow
75+
76+
Taking our same function from before and adding the [`@trace`](@ref) macro
77+
before the if expression will allow our compiled function to contain the
78+
condition.
79+
80+
```@example control_flow_tutorial
81+
using Reactant
82+
83+
function maybe_square(cond, x)
84+
@trace if cond
85+
x = x ^ 2
86+
else
87+
x = x
88+
end
89+
return x
90+
end
91+
```
92+
93+
First, let's note that [`@trace`](@ref) has no impact when the program is not
94+
run in a Reactant trace. As such, the function can still be used with plain
95+
Julia values. That makes it possible to include `@trace` in library code.
96+
97+
```@example control_flow_tutorial
98+
x = 2.
99+
maybe_square(true, x) == x ^ 2
100+
```
101+
102+
Then in our compiled version, we can pass a Reactant concrete boolean to
103+
conditionally control the output of the function.
104+
105+
```@example control_flow_tutorial
106+
cond = Reactant.ConcreteRNumber(true)
107+
x = Reactant.ConcreteRNumber(2.)
108+
109+
@jit(maybe_square(cond, x))[1] == Reactant.ConcreteRNumber(4.)
110+
```
111+
112+
This can also be confirmed by looking at the generated MLIR code which
113+
will contain a `stablehlo.if` operation:
114+
115+
```@example control_flow_tutorial
116+
@code_hlo maybe_square(cond, x)
117+
```
118+
119+
In our simple example, the condition is passed directly as an argument but
120+
the same mechanism is applied to conditions which are computed from within
121+
a function from traced arguments, leading to a traced condition.
122+
123+
### Loops
124+
125+
In addition to conditional evaluations, [`@trace`](@ref) also supports capturing
126+
loops. This is possible in the form of both for and while loops.
127+
This enables one to write algorithm that would not be possible otherwise such as
128+
performing computations until convergence or running a computation for an certain
129+
number of iterations which is only known during runtime.
130+
131+
Here is an example of a function which computes the cumsum in non-optimized manner
132+
using a for loop:
133+
134+
```@example control_flow_tutorial
135+
function cumsum!(x)
136+
v = zero(eltype(x))
137+
@trace for i in eachindex(x)
138+
v += @allowscalar x[i]
139+
@allowscalar x[i] = v
140+
end
141+
return x
142+
end
143+
144+
x = Reactant.to_rarray([1., 2., 3.])
145+
@jit(cumsum!(x))
146+
```
147+
148+
Similarly, one can trace while loops. The following is a minimal implementation of the
149+
[Sinkhorn-Knopp algorithm]() which aims to solve the entropic optimal transport problem:
150+
151+
```@example control_flow_tutorial
152+
using LinearAlgebra: Diagonal
153+
154+
function sinkhorn(μ, ν, C)
155+
λ = eltype(C)(0.8)
156+
K = @. exp(-C / λ)
157+
158+
u = fill!(similar(μ), one(eltype(μ)))
159+
v = similar(ν)
160+
161+
π = Diagonal(u) * K * Diagonal(v)
162+
err = typemax(eltype(π))
163+
164+
@trace while err >= 0.001
165+
v = ν ./ (K' * u)
166+
u = μ ./ (K * v)
167+
168+
new_π = Diagonal(u) * K * Diagonal(v)
169+
err = sum(abs2, new_π .- π)
170+
π = new_π
171+
end
172+
173+
return π
174+
end
175+
176+
a = Reactant.to_rarray(ones(Float32, 10) ./ 10)
177+
b = Reactant.to_rarray(ones(Float32, 12) ./ 12)
178+
C = Reactant.to_rarray(randn(Float32, 10, 12))
179+
180+
π = @jit sinkhorn(a, b, C)
181+
182+
# The sum of the transport plan is 1.
183+
sum(π)
184+
```
185+
186+
This implementation runs the algorithm until convergence (the transport plan has seen little change in the last iteration). Without [`@trace`](@ref) this would not be possible to implement since the termination condition is depending on traced values (in this case, the value of the transport plan).
187+
188+
!!! warning "Current limitations"
189+
190+
This is currently not allowed to include mutations as part of the while loop condition.
191+
192+
The for loop tracing does not support any arbitrary iterable. It supports integer ranges.

docs/src/tutorials/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33
- [Profiling](@ref profiling).
44
- [Multi-Host Environments](@ref distributed).
55
- [Local build of ReactantExtra](@ref local-build).
6+
- [Control flow](@ref control-flow).
67

78
We are currently working on adding more tutorials to Reactant!! Please check back soon!

lib/ReactantCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ReactantCore"
22
uuid = "a3311ec8-5e00-46d5-b541-4f83e724a433"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>"]
4-
version = "0.1.11"
4+
version = "0.1.12"
55

66
[deps]
77
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"

0 commit comments

Comments
 (0)