
Commit 0e97f10

feat: add support for Metal (#120)
1 parent a595bdb commit 0e97f10

6 files changed, 65 insertions(+), 10 deletions(-)

README.md

Lines changed: 2 additions & 1 deletion

@@ -45,6 +45,7 @@ classification models. TEI enables high-performance extraction for the most popu
 Ember, GTE and E5. TEI implements many features such as:

 * No model graph compilation step
+* Metal support for local execution on Macs
 * Small docker images and fast boot times. Get ready for true serverless!
 * Token based dynamic batching
 * Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention),

@@ -372,7 +373,7 @@ Then run:
 # On x86
 cargo install --path router -F candle -F mkl
 # On M1 or M2
-cargo install --path router -F candle -F accelerate
+cargo install --path router -F candle -F metal
 ```

 You can now launch Text Embeddings Inference on CPU with:

backends/candle/src/alibi.rs

Lines changed: 10 additions & 5 deletions
@@ -52,8 +52,9 @@ pub fn build_alibi_tensor(
     device: &Device,
     dtype: DType,
 ) -> Result<Tensor> {
-    let context_positions = Tensor::arange(0.0, num_positions as f64, device)?.unsqueeze(1)?;
-    let memory_positions = Tensor::arange(0.0, num_positions as f64, device)?.unsqueeze(0)?;
+    let context_positions =
+        Tensor::arange(0.0, num_positions as f64, &Device::Cpu)?.unsqueeze(1)?;
+    let memory_positions = Tensor::arange(0.0, num_positions as f64, &Device::Cpu)?.unsqueeze(0)?;

     let relative_positions = memory_positions.broadcast_sub(&context_positions)?.abs()?;
     // [num_heads, num_positions, num_positions]

@@ -63,13 +64,17 @@ pub fn build_alibi_tensor(
         .expand((num_heads, num_positions, num_positions))?;

     // [num_heads, 1, 1]
-    let slopes =
-        (Tensor::from_vec(alibi_head_slopes(num_heads), (num_heads, 1, 1), device)? * -1_f64)?;
+    let slopes = (Tensor::from_vec(
+        alibi_head_slopes(num_heads),
+        (num_heads, 1, 1),
+        &Device::Cpu,
+    )? * -1_f64)?;

     // [num_heads, num_positions, num_positions]
     let alibi = relative_positions.broadcast_mul(&slopes)?;

     alibi
         .reshape((1, num_heads, num_positions, num_positions))?
-        .to_dtype(dtype)
+        .to_dtype(dtype)?
+        .to_device(device)
 }
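
The recurring pattern in this hunk is to build the helper tensors (positions and ALiBi slopes) on the CPU and only move the finished bias to the requested device at the end, presumably because the Metal backend does not yet implement every op used during construction. A minimal sketch of that pattern, assuming the upstream `candle_core` crate name; the function below is illustrative only, not part of this commit:

```rust
use candle_core::{DType, Device, Result, Tensor};

/// Build an |i - j| position-difference matrix on the CPU, then transfer it
/// to the target device (e.g. Metal) once it is fully constructed.
fn relative_positions(n: usize, device: &Device, dtype: DType) -> Result<Tensor> {
    // arange and the broadcast ops all run on the CPU here
    let context = Tensor::arange(0.0, n as f64, &Device::Cpu)?.unsqueeze(1)?;
    let memory = Tensor::arange(0.0, n as f64, &Device::Cpu)?.unsqueeze(0)?;
    memory
        .broadcast_sub(&context)?
        .abs()?
        // only the finished tensor is cast and moved to the requested device
        .to_dtype(dtype)?
        .to_device(device)
}
```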

backends/candle/src/models/bert.rs

Lines changed: 2 additions & 2 deletions
@@ -506,7 +506,7 @@ impl BertModel {
                     input_ids.push(batch.input_ids[j]);
                     type_ids.push(batch.token_type_ids[j]);
                     position_ids.push(batch.position_ids[j]);
-                    attention_mask.push(1.0);
+                    attention_mask.push(1.0_f32);
                     attention_bias.push(0.0);
                 }

@@ -519,7 +519,7 @@ impl BertModel {
                         input_ids.push(0);
                         type_ids.push(0);
                         position_ids.push(0);
-                        attention_mask.push(0.0);
+                        attention_mask.push(0.0_f32);
                         attention_bias.push(f32::NEG_INFINITY);
                     }
                 }

backends/candle/src/models/jina.rs

Lines changed: 2 additions & 2 deletions
@@ -440,7 +440,7 @@ impl JinaBertModel {
                     input_ids.push(batch.input_ids[j]);
                     type_ids.push(batch.token_type_ids[j]);
                     position_ids.push(batch.position_ids[j]);
-                    attention_mask.push(1.0);
+                    attention_mask.push(1.0_f32);
                     attention_bias.push(0.0);
                 }

@@ -453,7 +453,7 @@ impl JinaBertModel {
                         input_ids.push(0);
                         type_ids.push(0);
                         position_ids.push(0);
-                        attention_mask.push(0.0);
+                        attention_mask.push(0.0_f32);
                         attention_bias.push(f32::NEG_INFINITY);
                     }
                 }
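
The `bert.rs` and `jina.rs` hunks make the same small change: the attention-mask literals get an explicit `_f32` suffix. A plausible reading (not stated in the commit) is that without the suffix Rust's inference defaults the mask vector to `f64`, and the resulting F64 tensor is not handled by the Metal backend; the suffix pins the dtype to F32 up front. A tiny, hypothetical illustration of that inference difference:

```rust
fn main() {
    // With no suffix and no other constraint, Rust defaults float literals to f64.
    let default_mask = vec![1.0, 0.0];
    // The `_f32` suffix pins the element type to f32 instead.
    let pinned_mask = vec![1.0_f32, 0.0_f32];

    // Prints 8 vs 4: bytes per element for f64 vs f32.
    println!(
        "{} vs {}",
        std::mem::size_of_val(&default_mask[0]),
        std::mem::size_of_val(&pinned_mask[0])
    );
}
```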

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,8 @@
 - sections:
   - local: local_cpu
     title: Using TEI locally with CPU
+  - local: local_metal
+    title: Using TEI locally with Metal
   - local: local_gpu
     title: Using TEI locally with GPU
   - local: private_models

docs/source/en/local_metal.md

Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+<!--Copyright 2023 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+
+⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
+rendered properly in your Markdown viewer.
+
+-->
+
+# Using TEI locally with Metal
+
+You can install `text-embeddings-inference` locally to run it on your own Mac with Metal support.
+Here are the step-by-step instructions for installation:
+
+## Step 1: Install Rust
+
+[Install Rust](https://rustup.rs/) on your machine by running the following in your terminal, then following the instructions:
+
+```shell
+curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
+```
+
+## Step 2: Install with Metal support
+
+```shell
+cargo install --path router -F candle -F metal
+```
+
+## Step 3: Launch Text Embeddings Inference
+
+Once the installation has successfully completed, you can launch Text Embeddings Inference with Metal with the following command:
+
+```shell
+model=BAAI/bge-large-en-v1.5
+revision=refs/pr/5
+
+text-embeddings-router --model-id $model --revision $revision --port 8080
+```
+
+Now you are ready to use `text-embeddings-inference` locally on your machine.
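
Once the router is up, a quick way to check the Metal build end-to-end is to call the `/embed` route documented in the TEI README; a minimal sketch, assuming the server is listening on port 8080 as launched above:

```shell
curl 127.0.0.1:8080/embed \
    -X POST \
    -d '{"inputs": "What is Deep Learning?"}' \
    -H 'Content-Type: application/json'
```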
