Skip to content

Commit a93a83a

Browse files
fix: consistency with NNlib (#1328)
* fix: conditionals in softmax
* fix: default to HIGH precision
* fix: revert high precision
* docs: add a dedicated FAQs section in the docs
* feat: add a convolution precision scopedvalue
* Update test/config.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix: config
* chore: bump version for release

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent fbaaadd commit a93a83a

File tree

12 files changed

+168
-154
lines changed

12 files changed

+168
-154
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.116"
4+
version = "0.2.117"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

docs/src/.vitepress/config.mts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ export default defineConfig({
7676
items: [
7777
{ text: "Introduction", link: "/introduction" },
7878
{ text: "Configuration", link: "/introduction/configuration" },
79+
{ text: "FAQs", link: "/introduction/FAQs" },
7980
],
8081
},
8182
{ text: "Benchmarks", link: "https://enzymead.github.io/Reactant.jl/benchmarks/" },
@@ -140,6 +141,7 @@ export default defineConfig({
140141
items: [
141142
{ text: "Introduction", link: "/introduction" },
142143
{ text: "Configuration", link: "/introduction/configuration" },
144+
{ text: "FAQs", link: "/introduction/FAQs" },
143145
],
144146
}
145147
],

docs/src/api/config.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Reactant.with_config
2929

3030
```@docs
3131
Reactant.DotGeneralAlgorithmPreset
32-
Reactant.DotGeneralPrecision
32+
Reactant.PrecisionConfig
3333
Reactant.DotGeneralAlgorithm
3434
```
3535

docs/src/introduction/FAQs.md

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# FAQs
2+
3+
## XLA auto-tuner: Results do not match the reference. This is likely a bug/unexpected loss of precision
4+
5+
If you see this error with the CUDA backend, use a scoped value to increase the precision
6+
of the dot-general algorithm.
7+
8+
```julia
9+
Reactant.with_config(; dot_general_precision=PrecisionConfig.HIGH) do
10+
@compile ...
11+
end
12+
```
13+
14+
For more information, see [this XLA issue](https://github.com/openxla/xla/issues/23934).
15+
16+
## Emptying the cache to avoid OOM issues
17+
18+
When you encounter OOM (Out of Memory) errors, you can try to clear the cache by using
19+
Julia's built-in `GC.gc()` between memory-intensive operations.
20+
21+
!!! note
22+
This will only free memory which is not currently live. If the result of compiled
23+
of a compiled function was stored in a vector, it will still be alive and `GC.gc()` won't free it.
24+
25+
```julia
26+
using Reactant
27+
n = 500_000_000
28+
input1 = Reactant.ConcreteRArray(ones(n))
29+
input2 = Reactant.ConcreteRArray(ones(n))
30+
31+
function sin_add(x, y)
32+
return sin.(x) .+ y
33+
end
34+
35+
f = @compile sin_add(input1,input2)
36+
37+
for i = 1:10
38+
GC.gc()
39+
@info "gc... $i"
40+
f(input1, input2) # May cause OOM here for a 24GB GPU if GC is not used
41+
end
42+
```
43+
44+
If you **don't** use `GC.gc()` here, this may cause an OOM:
45+
46+
```bash
47+
[ Info: gc... 1
48+
[ Info: gc... 2
49+
[ Info: gc... 3
50+
...
51+
E0105 09:48:28.755177 110350 pjrt_stream_executor_client.cc:3088] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4000000000 bytes.
52+
ERROR: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4000000000 bytes.
53+
54+
Stacktrace:
55+
[1] reactant_err(msg::Cstring)
56+
@ Reactant.XLA ~/.julia/packages/Reactant/7m11i/src/XLA.jl:104
57+
[2] macro expansion
58+
@ ~/.julia/packages/Reactant/7m11i/src/XLA.jl:357 [inlined]
59+
[3] ExecutableCall
60+
@ ~/.julia/packages/Reactant/7m11i/src/XLA.jl:334 [inlined]
61+
[4] macro expansion
62+
@ ~/.julia/packages/Reactant/7m11i/src/Compiler.jl:798 [inlined]
63+
[5] (::Reactant.Compiler.Thunk{…})(::ConcreteRArray{…}, ::ConcreteRArray{…})
64+
@ Reactant.Compiler ~/.julia/packages/Reactant/7m11i/src/Compiler.jl:909
65+
[6] top-level scope
66+
@ ./REPL[7]:4
67+
Some type information was truncated. Use `show(err)` to see complete types.
68+
```
69+
70+
After using Julia's built-in `GC.gc()`:
71+
72+
```bash
73+
[ Info: gc... 1
74+
[ Info: gc... 2
75+
[ Info: gc... 3
76+
[ Info: gc... 4
77+
[ Info: gc... 5
78+
[ Info: gc... 6
79+
[ Info: gc... 7
80+
[ Info: gc... 8
81+
[ Info: gc... 9
82+
[ Info: gc... 10
83+
```

docs/src/introduction/index.md

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -53,83 +53,3 @@ f = @compile sinsum_add(input1,input2)
5353
# one can now run the program
5454
f(input1, input2)
5555
```
56-
57-
58-
## Tips
59-
60-
### Empty Cache
61-
62-
When you encounter OOM (Out of Memory) errors, you can try to clear the cache by using Julia's builtin `GC.gc()` between memory-intensive operations.
63-
64-
!!! note
65-
This will only free memory which is not currently live. If the result of compiled function was stored in a vector, it will still be alive and `GC.gc()` won't free it.
66-
67-
```julia
68-
using Reactant
69-
n = 500_000_000
70-
input1 = Reactant.ConcreteRArray(ones(n))
71-
input2 = Reactant.ConcreteRArray(ones(n))
72-
73-
function sin_add(x, y)
74-
return sin.(x) .+ y
75-
end
76-
77-
f = @compile sin_add(input1,input2)
78-
79-
for i = 1:10
80-
GC.gc()
81-
@info "gc... $i"
82-
f(input1, input2) # May cause OOM here for a 24GB GPU if GC is not used
83-
end
84-
```
85-
86-
If you **don't** use `GC.gc()` here, this may cause an OOM:
87-
88-
89-
90-
```bash
91-
[ Info: gc... 1
92-
[ Info: gc... 2
93-
[ Info: gc... 3
94-
...
95-
E0105 09:48:28.755177 110350 pjrt_stream_executor_client.cc:3088] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4000000000 bytes.
96-
ERROR: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4000000000 bytes.
97-
98-
Stacktrace:
99-
[1] reactant_err(msg::Cstring)
100-
@ Reactant.XLA ~/.julia/packages/Reactant/7m11i/src/XLA.jl:104
101-
[2] macro expansion
102-
@ ~/.julia/packages/Reactant/7m11i/src/XLA.jl:357 [inlined]
103-
[3] ExecutableCall
104-
@ ~/.julia/packages/Reactant/7m11i/src/XLA.jl:334 [inlined]
105-
[4] macro expansion
106-
@ ~/.julia/packages/Reactant/7m11i/src/Compiler.jl:798 [inlined]
107-
[5] (::Reactant.Compiler.Thunk{…})(::ConcreteRArray{…}, ::ConcreteRArray{…})
108-
@ Reactant.Compiler ~/.julia/packages/Reactant/7m11i/src/Compiler.jl:909
109-
[6] top-level scope
110-
@ ./REPL[7]:4
111-
Some type information was truncated. Use `show(err)` to see complete types.
112-
```
113-
114-
115-
After using Julia's built-in `GC.gc()`:
116-
117-
118-
119-
```bash
120-
[ Info: gc... 1
121-
[ Info: gc... 2
122-
[ Info: gc... 3
123-
[ Info: gc... 4
124-
[ Info: gc... 5
125-
[ Info: gc... 6
126-
[ Info: gc... 7
127-
[ Info: gc... 8
128-
[ Info: gc... 9
129-
[ Info: gc... 10
130-
```
131-
132-
133-
134-
135-

ext/ReactantNNlibExt/Implementations.jl

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,26 @@ for (jlop, hloop) in (
77
end
88

99
function NNlib.softmax!(out::AnyTracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
10-
max_ = NNlib.fast_maximum(x; dims)
11-
# XXX: Once reverse mode of if is properly supported, we can make it @trace
12-
# zero_num = TracedUtils.promote_to(TracedRNumber{T}, 0)
13-
# one_num = TracedUtils.promote_to(TracedRNumber{T}, 1)
14-
# @trace if all(isfinite, max_)
15-
@. out = exp(x - max_)
16-
# else
17-
# cond = max_ .== Inf
18-
# true_pred = ifelse.(x .== Inf, one_num, zero_num)
19-
# @. out = ifelse(cond, true_pred, exp(x - max_))
20-
# end
21-
tmp = dims isa Colon ? sum(out) : sum!(max_, out)
22-
out ./= tmp
10+
max_ = maximum(x; dims)
11+
diff = exp.(x .- max_)
12+
@trace if all(isfinite, max_)
13+
@. out = diff
14+
else
15+
@. out = ifelse(isinf(max_), ifelse(isinf(x), T(1), T(0)), diff)
16+
end
17+
out ./= sum(out; dims)
2318
return out
2419
end
2520

2621
function NNlib.logsoftmax!(out::AnyTracedRArray{T}, x::AbstractArray; dims=1) where {T}
27-
max_ = NNlib.fast_maximum(x; dims)
28-
# XXX: Once reverse mode of if is properly supported, we can make it @trace
29-
# inf_num = TracedUtils.promote_to(TracedRNumber{T}, Inf)
30-
# zero_num = TracedUtils.promote_to(TracedRNumber{T}, 0)
31-
# @trace if all(isfinite, max_)
32-
@. out = x - max_
33-
# else
34-
# cond = max_ .== Inf
35-
# true_pred = ifelse.(x .== Inf, zero_num, -inf_num)
36-
# @. out = ifelse(cond, true_pred, x - max_)
37-
# end
38-
@fastmath log_ = log.(sum(exp, out; dims))
39-
out .-= log_
22+
max_ = maximum(x; dims)
23+
diff = x .- max_
24+
@trace if all(isfinite, max_)
25+
@. out = diff
26+
else
27+
@. out = ifelse(isinf(max_), ifelse(isinf(x), T(0), -T(Inf)), diff)
28+
end
29+
out .-= log.(sum(exp, out; dims))
4030
return out
4131
end
4232

@@ -111,6 +101,10 @@ function overloaded_conv!(
111101
rhs_dilation=collect(dilation),
112102
feature_group_count,
113103
batch_group_count=1,
104+
precision_config=MLIR.IR.Attribute([
105+
MLIR.IR.Attribute(Reactant.CONVOLUTION_PRECISION[]),
106+
MLIR.IR.Attribute(Reactant.CONVOLUTION_PRECISION[]),
107+
]),
114108
)
115109
set_mlir_data!(y, Reactant.MLIR.IR.result(conv))
116110
return y
@@ -206,6 +200,10 @@ function overloaded_∇conv_filter!(
206200
rhs_dilation=collect(stride),
207201
feature_group_count,
208202
batch_group_count,
203+
precision_config=MLIR.IR.Attribute([
204+
MLIR.IR.Attribute(Reactant.CONVOLUTION_PRECISION[]),
205+
MLIR.IR.Attribute(Reactant.CONVOLUTION_PRECISION[]),
206+
]),
209207
)
210208
set_mlir_data!(dw, MLIR.IR.result(conv))
211209

@@ -326,6 +324,10 @@ function overloaded_∇conv_data!(
326324
dimension_numbers,
327325
feature_group_count,
328326
batch_group_count=1,
327+
precision_config=MLIR.IR.Attribute([
328+
MLIR.IR.Attribute(Reactant.CONVOLUTION_PRECISION[]),
329+
MLIR.IR.Attribute(Reactant.CONVOLUTION_PRECISION[]),
330+
]),
329331
)
330332
set_mlir_data!(dx, MLIR.IR.result(conv))
331333

src/Configuration.jl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using ScopedValues: ScopedValues, ScopedValue
22

33
export with_config
4-
export DotGeneralAlgorithmPreset, DotGeneralPrecision, DotGeneralAlgorithm
4+
export DotGeneralAlgorithmPreset, PrecisionConfig, DotGeneralAlgorithm
55

66
"""
77
with_config(f; kwargs...)
@@ -27,12 +27,15 @@ scope will use the provided values.
2727
[`DotGeneralAlgorithm`](@ref) or [`DotGeneralAlgorithmPreset`](@ref). Defaults to
2828
`DotGeneralAlgorithmPreset.DEFAULT`.
2929
- `dot_general_precision`: Precision for `stablehlo.dot_general`. Can be `nothing`,
30-
or [`DotGeneralPrecision`](@ref). Defaults to `DotGeneralPrecision.DEFAULT`.
30+
or [`PrecisionConfig`](@ref). Defaults to `PrecisionConfig.DEFAULT`.
31+
- `convolution_precision`: Precision for `stablehlo.convolution`. Can be `nothing`,
32+
or [`PrecisionConfig`](@ref). Defaults to `PrecisionConfig.DEFAULT`.
3133
"""
3234
function with_config(
3335
f;
3436
dot_general_algorithm=missing,
3537
dot_general_precision=missing,
38+
convolution_precision=missing,
3639
lower_partialsort_to_approx_top_k=missing,
3740
fallback_approx_top_k_lowering=missing,
3841
)
@@ -41,6 +44,8 @@ function with_config(
4144
(config_vars = (config_vars..., DOT_GENERAL_ALGORITHM => dot_general_algorithm))
4245
dot_general_precision !== missing &&
4346
(config_vars = (config_vars..., DOT_GENERAL_PRECISION => dot_general_precision))
47+
convolution_precision !== missing &&
48+
(config_vars = (config_vars..., CONVOLUTION_PRECISION => convolution_precision))
4449
lower_partialsort_to_approx_top_k !== missing && (
4550
config_vars = (
4651
config_vars...,
@@ -63,7 +68,7 @@ const FALLBACK_APPROX_TOP_K_LOWERING = ScopedValue(true)
6368

6469
# DotGeneral Attributes Configuration
6570
"""
66-
DotGeneralPrecision
71+
PrecisionConfig
6772
6873
Controls the `precision_config` for `stablehlo.dot_general`. Valid values are:
6974
@@ -73,26 +78,34 @@ Controls the `precision_config` for `stablehlo.dot_general`. Valid values are:
7378
7479
The following functions are available:
7580
76-
`MLIR.IR.Attribute(precision::DotGeneralPrecision.T)`
81+
`MLIR.IR.Attribute(precision::PrecisionConfig.T)`
7782
"""
78-
@enumx DotGeneralPrecision begin
83+
@enumx PrecisionConfig begin
7984
DEFAULT
8085
HIGH
8186
HIGHEST
8287
end
8388

89+
Base.@deprecate_binding DotGeneralPrecision PrecisionConfig
90+
8491
const DOT_GENERAL_PRECISION = ScopedValue{
85-
Union{DotGeneralPrecision.T,Nothing,Tuple{DotGeneralPrecision.T,DotGeneralPrecision.T}}
92+
Union{PrecisionConfig.T,Nothing,Tuple{PrecisionConfig.T,PrecisionConfig.T}}
93+
}(
94+
PrecisionConfig.DEFAULT
95+
)
96+
97+
const CONVOLUTION_PRECISION = ScopedValue{
98+
Union{PrecisionConfig.T,Nothing,Tuple{PrecisionConfig.T,PrecisionConfig.T}}
8699
}(
87-
DotGeneralPrecision.DEFAULT
100+
PrecisionConfig.DEFAULT
88101
)
89102

90-
function MLIR.IR.Attribute(precision::DotGeneralPrecision.T)
91-
precision_str = if precision == DotGeneralPrecision.DEFAULT
103+
function MLIR.IR.Attribute(precision::PrecisionConfig.T)
104+
precision_str = if precision == PrecisionConfig.DEFAULT
92105
"DEFAULT"
93-
elseif precision == DotGeneralPrecision.HIGH
106+
elseif precision == PrecisionConfig.HIGH
94107
"HIGH"
95-
elseif precision == DotGeneralPrecision.HIGHEST
108+
elseif precision == PrecisionConfig.HIGHEST
96109
"HIGHEST"
97110
end
98111
return MLIR.IR.Attribute(

src/Overlay.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,15 +165,15 @@ end
165165

166166
@reactant_overlay @noinline function Base._all(f, x::AbstractArray{T}, dims) where {T}
167167
if T <: TracedRNumber && T !== Union{}
168-
return TracedRArrayOverrides.overloaded_all(f, x, dims)
168+
return TracedRArrayOverrides.overloaded_mapreduce(f, &, x; dims)
169169
else
170170
return Base.inferencebarrier(Base._all)(f, x, dims)
171171
end
172172
end
173173

174174
@reactant_overlay @noinline function Base.any(f, x::AbstractArray{T}, dims) where {T}
175175
if T <: TracedRNumber && T !== Union{}
176-
return TracedRArrayOverrides.overloaded_any(f, x, dims)
176+
return TracedRArrayOverrides.overloaded_mapreduce(f, |, x; dims)
177177
else
178178
return Base.inferencebarrier(Base.any)(f, x, dims)
179179
end

0 commit comments

Comments
 (0)