Skip to content

Commit cb9de7a

Browse files
vrdn-23 and alvarobartt authored
Default HiddenAct::Gelu to GeLU + tanh in favour of GeLU erf (#753)
Co-authored-by: Alvaro Bartolome <[email protected]>
1 parent 02f60f0 commit cb9de7a

File tree

3 files changed

+99
-84
lines changed

3 files changed

+99
-84
lines changed

backends/candle/src/layers/linear.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ use serde::Deserialize;
55
#[derive(Debug, Deserialize, PartialEq, Clone)]
66
#[serde(rename_all = "lowercase")]
77
pub enum HiddenAct {
8-
#[serde(alias = "gelu_pytorch_tanh")]
8+
// NOTE: `GeluErf` is excluded due to incompatibility with cuBLASLt, as only GeLU + tanh
9+
// approximation is implemented for efficiency, so GeLU is standardized to tanh approx. with
10+
// slight numerical deviation from GeLU erf (negligible on inference quality)
11+
#[serde(alias = "gelu_new", alias = "gelu_pytorch_tanh")]
912
Gelu,
1013
Relu,
1114
Silu,

backends/candle/src/lib.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,19 @@ impl CandleBackend {
180180
let config: String = std::fs::read_to_string(model_path.join("config.json"))
181181
.context("Unable to read config file")
182182
.map_err(|err| BackendError::Start(format!("{err:?}")))?;
183-
let config: Config = serde_json::from_str(&config)
183+
184+
let config_value: serde_json::Value = serde_json::from_str(&config)
185+
.context("Unable to parse config.json")
186+
.map_err(|err| BackendError::Start(format!("{err:?}")))?;
187+
188+
if let Some(hidden_act) = config_value.get("hidden_act").and_then(|v| v.as_str()) {
189+
if hidden_act == "gelu" {
190+
// NOTE: https://github.com/huggingface/text-embeddings-inference/pull/753
191+
tracing::warn!("The `config.json` contains `hidden_act=gelu` and GeLU + tanh approximation will be used instead of exact GeLU (aka. GeLU erf), which might lead to subtle differences with Transformers or Sentence Transformers outputs which use exact GeLU when `hidden_act=gelu`, unless specified otherwise. GeLU + tanh is more efficient and more consistent across devices (e.g., cuBLASLt comes with fused GeLU + tanh), and will have negligible impact on the inference quality.");
192+
}
193+
}
194+
195+
let config: Config = serde_json::from_value(config_value)
184196
.context("Model is not supported")
185197
.map_err(|err| BackendError::Start(format!("{err:?}")))?;
186198

backends/candle/tests/snapshots/test_bert__emotions_batch.snap

Lines changed: 82 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -2,87 +2,87 @@
22
source: backends/candle/tests/test_bert.rs
33
expression: predictions_batch
44
---
5-
- - -6.548559
6-
- -6.302024
7-
- -4.8671727
8-
- -3.9600255
9-
- -4.6329865
10-
- -6.2816987
11-
- -6.069644
12-
- -5.7742686
13-
- -6.9259467
14-
- -6.1909447
15-
- -5.67395
5+
- - -6.5485673
6+
- -6.3020196
7+
- -4.86717
8+
- -3.9600184
9+
- -4.632993
10+
- -6.2817054
11+
- -6.069636
12+
- -5.7742705
13+
- -6.925953
14+
- -6.190939
15+
- -5.6739373
1616
- -6.1698227
17-
- -7.513461
18-
- -6.865867
19-
- -7.186479
20-
- -7.128109
21-
- -8.210709
22-
- -7.0171394
23-
- -7.1321163
24-
- -8.533409
25-
- -6.2294865
26-
- -8.742306
27-
- -5.7792044
28-
- -8.657227
29-
- -8.258305
30-
- -6.64832
31-
- -7.4060283
32-
- 3.046496
33-
- - -5.8167515
34-
- -6.6119466
35-
- -5.2771955
36-
- -2.6306503
37-
- -4.6419163
38-
- -5.579778
39-
- -5.797174
40-
- -6.0305815
41-
- -5.8720746
42-
- 0.45377323
43-
- -3.0235887
44-
- -5.3944407
45-
- -5.186683
46-
- -6.2649117
47-
- -6.1962767
48-
- -6.97937
49-
- -5.5674877
50-
- -5.521044
51-
- -5.8899207
52-
- -4.8699703
53-
- -5.6259933
54-
- -7.6109924
55-
- -4.3881936
56-
- -6.039008
57-
- -4.934696
58-
- -0.6715916
59-
- -6.399376
60-
- -2.4499295
61-
- - -6.548559
62-
- -6.302024
63-
- -4.8671727
64-
- -3.9600255
65-
- -4.6329865
66-
- -6.2816987
67-
- -6.069644
68-
- -5.7742686
69-
- -6.9259467
70-
- -6.1909447
71-
- -5.67395
17+
- -7.5134573
18+
- -6.8658743
19+
- -7.1864815
20+
- -7.128115
21+
- -8.2107115
22+
- -7.017146
23+
- -7.132131
24+
- -8.533407
25+
- -6.229486
26+
- -8.742311
27+
- -5.7792006
28+
- -8.65723
29+
- -8.258308
30+
- -6.648321
31+
- -7.406026
32+
- 3.0464942
33+
- - -5.816747
34+
- -6.611947
35+
- -5.2771983
36+
- -2.6306484
37+
- -4.6419153
38+
- -5.5797825
39+
- -5.7971735
40+
- -6.030578
41+
- -5.872076
42+
- 0.45378062
43+
- -3.0235896
44+
- -5.3944383
45+
- -5.18668
46+
- -6.264913
47+
- -6.196284
48+
- -6.9793677
49+
- -5.567489
50+
- -5.5210495
51+
- -5.889915
52+
- -4.8699794
53+
- -5.625993
54+
- -7.6109934
55+
- -4.388194
56+
- -6.0390115
57+
- -4.934693
58+
- -0.6715966
59+
- -6.3993735
60+
- -2.4499245
61+
- - -6.5485673
62+
- -6.3020196
63+
- -4.86717
64+
- -3.9600184
65+
- -4.632993
66+
- -6.2817054
67+
- -6.069636
68+
- -5.7742705
69+
- -6.925953
70+
- -6.190939
71+
- -5.6739373
7272
- -6.1698227
73-
- -7.513461
74-
- -6.865867
75-
- -7.186479
76-
- -7.128109
77-
- -8.210709
78-
- -7.0171394
79-
- -7.1321163
80-
- -8.533409
81-
- -6.2294865
82-
- -8.742306
83-
- -5.7792044
84-
- -8.657227
85-
- -8.258305
86-
- -6.64832
87-
- -7.4060283
88-
- 3.046496
73+
- -7.5134573
74+
- -6.8658743
75+
- -7.1864815
76+
- -7.128115
77+
- -8.2107115
78+
- -7.017146
79+
- -7.132131
80+
- -8.533407
81+
- -6.229486
82+
- -8.742311
83+
- -5.7792006
84+
- -8.65723
85+
- -8.258308
86+
- -6.648321
87+
- -7.406026
88+
- 3.0464942

0 commit comments

Comments
 (0)