|
| 1 | +# [Control Flow](@id control-flow) |
| 2 | + |
| 3 | +Reactant currently uses a tracing system to capture array operations into a new |
| 4 | +program. As such, the provided function is executed with [`TracedRArray`](@ref) |
| 5 | +as inputs instead of [`ConcreteRArray`](@ref). This means that during tracing |
| 6 | +only operations affecting such arrays are captured by Reactant into the new |
| 7 | +program. |
| 8 | + |
| 9 | +In practice, this means that Julia native control flow constructs are not |
| 10 | +captured by Reactant. |
| 11 | + |
| 12 | +Consider the following function which has a conditional control flow depending |
| 13 | +on one of its argument which is a boolean: |
| 14 | + |
| 15 | +```@example control_flow_tutorial |
| 16 | +using Reactant |
| 17 | +
|
| 18 | +function maybe_square(cond, x) |
| 19 | + if cond |
| 20 | + x = x .^ 2 |
| 21 | + else |
| 22 | + x = x |
| 23 | + end |
| 24 | + return x |
| 25 | +end |
| 26 | +``` |
| 27 | + |
| 28 | +We can confirm by compiling our function and noticing that the result does not |
| 29 | +depend on the argument provided to the compiled function. |
| 30 | + |
| 31 | +```@example control_flow_tutorial |
| 32 | +x = Reactant.ConcreteRArray(randn(Float32, 100)) |
| 33 | +
|
| 34 | +maybe_square_compiled = @compile maybe_square(true, x) |
| 35 | +maybe_square_compiled(false, x) == maybe_square_compiled(true, x) |
| 36 | +``` |
| 37 | + |
| 38 | +But instead, it depends on the value that was provided during tracing to the |
| 39 | +initial `@compile` invocation. This is also confirmed when looking at the |
| 40 | +code generated during tracing which does not contain any conditional. |
| 41 | + |
| 42 | +```@example control_flow_tutorial |
| 43 | +@code_hlo maybe_square(false, x) |
| 44 | +``` |
| 45 | + |
| 46 | +The same behaviour can be observed when using loops. In the following example, |
| 47 | +the loop is "unrolled" because it is not captured in the program. The optimizer |
| 48 | +then fuses all additions to add `n = 10` directly to the argument. |
| 49 | + |
| 50 | +```@example control_flow_tutorial |
| 51 | +function add_n(x, n) |
| 52 | + for _ in 1:n |
| 53 | + x .+= 1 |
| 54 | + end |
| 55 | + return x |
| 56 | +end |
| 57 | +
|
| 58 | +x = Reactant.to_rarray(zeros(Int, 10)) |
| 59 | +n = 10 |
| 60 | +@code_hlo add_n(x, n) |
| 61 | +``` |
| 62 | + |
| 63 | +In the next section, we will see what mechanism Reactant offers to integrate |
| 64 | +data-dependent control flow in the captured programs. |
| 65 | + |
| 66 | +## Data-dependent Control Flow using [`@trace`](@ref) |
| 67 | + |
| 68 | +During tracing the arrays contain no data and only information about their shape |
| 69 | +and data type. As such, it is not possible to execute conditions that would |
| 70 | +depend on the value of an array. For these cases, ReactantCore provides the |
| 71 | +[`@trace`](@ref) macro to allow capturing control flow expressions in the |
| 72 | +compiled program. |
| 73 | + |
| 74 | +### Conditional Control Flow |
| 75 | + |
| 76 | +Taking our same function from before and adding the [`@trace`](@ref) macro |
| 77 | +before the if expression will allow our compiled function to contain the |
| 78 | +condition. |
| 79 | + |
| 80 | +```@example control_flow_tutorial |
| 81 | +using Reactant |
| 82 | +
|
| 83 | +function maybe_square(cond, x) |
| 84 | + @trace if cond |
| 85 | + x = x ^ 2 |
| 86 | + else |
| 87 | + x = x |
| 88 | + end |
| 89 | + return x |
| 90 | +end |
| 91 | +``` |
| 92 | + |
| 93 | +First, let's note that [`@trace`](@ref) has no impact when the program is not |
| 94 | +run in a Reactant trace. As such, the function can still be used with plain |
| 95 | +Julia values. That makes it possible to include `@trace` in library code. |
| 96 | + |
| 97 | +```@example control_flow_tutorial |
| 98 | +x = 2. |
| 99 | +maybe_square(true, x) == x ^ 2 |
| 100 | +``` |
| 101 | + |
| 102 | +Then in our compiled version, we can pass a Reactant concrete boolean to |
| 103 | +conditionally control the output of the function. |
| 104 | + |
| 105 | +```@example control_flow_tutorial |
| 106 | +cond = Reactant.ConcreteRNumber(true) |
| 107 | +x = Reactant.ConcreteRNumber(2.) |
| 108 | +
|
| 109 | +@jit(maybe_square(cond, x))[1] == Reactant.ConcreteRNumber(4.) |
| 110 | +``` |
| 111 | + |
| 112 | +This can also be confirmed by looking at the generated MLIR code which |
| 113 | +will contain a `stablehlo.if` operation: |
| 114 | + |
| 115 | +```@example control_flow_tutorial |
| 116 | +@code_hlo maybe_square(cond, x) |
| 117 | +``` |
| 118 | + |
| 119 | +In our simple example, the condition is passed directly as an argument but |
| 120 | +the same mechanism is applied to conditions which are computed from within |
| 121 | +a function from traced arguments, leading to a traced condition. |
| 122 | + |
| 123 | +### Loops |
| 124 | + |
| 125 | +In addition to conditional evaluations, [`@trace`](@ref) also supports capturing |
| 126 | +loops. This is possible in the form of both for and while loops. |
| 127 | +This enables one to write algorithm that would not be possible otherwise such as |
| 128 | +performing computations until convergence or running a computation for an certain |
| 129 | +number of iterations which is only known during runtime. |
| 130 | + |
| 131 | +Here is an example of a function which computes the cumsum in non-optimized manner |
| 132 | +using a for loop: |
| 133 | + |
| 134 | +```@example control_flow_tutorial |
| 135 | +function cumsum!(x) |
| 136 | + v = zero(eltype(x)) |
| 137 | + @trace for i in eachindex(x) |
| 138 | + v += @allowscalar x[i] |
| 139 | + @allowscalar x[i] = v |
| 140 | + end |
| 141 | + return x |
| 142 | +end |
| 143 | +
|
| 144 | +x = Reactant.to_rarray([1., 2., 3.]) |
| 145 | +@jit(cumsum!(x)) |
| 146 | +``` |
| 147 | + |
| 148 | +Similarly, one can trace while loops. The following is a minimal implementation of the |
| 149 | +[Sinkhorn-Knopp algorithm]() which aims to solve the entropic optimal transport problem: |
| 150 | + |
| 151 | +```@example control_flow_tutorial |
| 152 | +using LinearAlgebra: Diagonal |
| 153 | +
|
| 154 | +function sinkhorn(μ, ν, C) |
| 155 | + λ = eltype(C)(0.8) |
| 156 | + K = @. exp(-C / λ) |
| 157 | +
|
| 158 | + u = fill!(similar(μ), one(eltype(μ))) |
| 159 | + v = similar(ν) |
| 160 | +
|
| 161 | + π = Diagonal(u) * K * Diagonal(v) |
| 162 | + err = typemax(eltype(π)) |
| 163 | +
|
| 164 | + @trace while err >= 0.001 |
| 165 | + v = ν ./ (K' * u) |
| 166 | + u = μ ./ (K * v) |
| 167 | +
|
| 168 | + new_π = Diagonal(u) * K * Diagonal(v) |
| 169 | + err = sum(abs2, new_π .- π) |
| 170 | + π = new_π |
| 171 | + end |
| 172 | +
|
| 173 | + return π |
| 174 | +end |
| 175 | +
|
| 176 | +a = Reactant.to_rarray(ones(Float32, 10) ./ 10) |
| 177 | +b = Reactant.to_rarray(ones(Float32, 12) ./ 12) |
| 178 | +C = Reactant.to_rarray(randn(Float32, 10, 12)) |
| 179 | +
|
| 180 | +π = @jit sinkhorn(a, b, C) |
| 181 | +
|
| 182 | +# The sum of the transport plan is 1. |
| 183 | +sum(π) |
| 184 | +``` |
| 185 | + |
| 186 | +This implementation runs the algorithm until convergence (the transport plan has seen little change in the last iteration). Without [`@trace`](@ref) this would not be possible to implement since the termination condition is depending on traced values (in this case, the value of the transport plan). |
| 187 | + |
| 188 | +!!! warning "Current limitations" |
| 189 | + |
| 190 | + This is currently not allowed to include mutations as part of the while loop condition. |
| 191 | + |
| 192 | + The for loop tracing does not support any arbitrary iterable. It supports integer ranges. |
0 commit comments