Skip to content

Commit 1e90b98

Browse files
authored
feat: add sentence transformers support (#45)
- Add support for sentence transformers - Add `mean_pool` tensor operation
1 parent 3b4a68d commit 1e90b98

File tree

29 files changed: +578 −31 lines changed

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ licenses:
1919
@echo "Generating licenses..."
2020
@cargo bundle-licenses --format yaml --output THIRDPARTY.yml
2121

22+
.PHONY:
23+
clippy:
24+
cargo clippy --fix --all-features --allow-dirty
25+
2226
.PHONY: pre-commit
2327
pre-commit:
2428
@uv run pre-commit run --all-files

encoderfile/build.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
3232
"proto/embedding.proto",
3333
"proto/sequence_classification.proto",
3434
"proto/token_classification.proto",
35+
"proto/sentence_embedding.proto",
3536
],
3637
&[
3738
"proto/embedding",
3839
"proto/sequence_classification",
3940
"proto/token_classification",
41+
"proto/sentence_embedding",
4042
],
4143
)?;
4244

encoderfile/proto/metadata.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ enum ModelType {
1515
EMBEDDING = 1;
1616
SEQUENCE_CLASSIFICATION = 2;
1717
TOKEN_CLASSIFICATION = 3;
18+
SENTENCE_EMBEDDING = 4;
1819
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
syntax = "proto3";
2+
3+
package encoderfile.sentence_embedding;
4+
5+
import "proto/token.proto";
6+
import "proto/metadata.proto";
7+
8+
service SentenceEmbeddingInference {
9+
rpc Predict(SentenceEmbeddingRequest) returns (SentenceEmbeddingResponse);
10+
rpc GetModelMetadata(encoderfile.metadata.GetModelMetadataRequest) returns (encoderfile.metadata.GetModelMetadataResponse);
11+
}
12+
13+
message SentenceEmbeddingRequest {
14+
repeated string inputs = 1;
15+
map<string, string> metadata = 3;
16+
}
17+
18+
message SentenceEmbeddingResponse {
19+
// len(embeddings) == len(inputs)
20+
repeated SentenceEmbedding results = 1;
21+
string model_id = 2;
22+
map<string, string> metadata = 3;
23+
}
24+
25+
message SentenceEmbedding {
26+
repeated float embedding = 1;
27+
}

encoderfile/src/cli.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
use crate::{
22
common::{
3-
EmbeddingRequest, ModelType, SequenceClassificationRequest, TokenClassificationRequest,
3+
EmbeddingRequest, ModelType, SentenceEmbeddingRequest, SequenceClassificationRequest,
4+
TokenClassificationRequest,
45
},
56
runtime::AppState,
67
server::{run_grpc, run_http, run_mcp},
7-
services::{embedding, sequence_classification, token_classification},
8+
services::{embedding, sentence_embedding, sequence_classification, token_classification},
89
};
910
use anyhow::Result;
1011
use clap_derive::{Parser, Subcommand, ValueEnum};
@@ -145,6 +146,11 @@ impl Commands {
145146

146147
generate_cli_route!(request, token_classification, format, out_dir, state)
147148
}
149+
ModelType::SentenceEmbedding => {
150+
let request = SentenceEmbeddingRequest { inputs, metadata };
151+
152+
generate_cli_route!(request, sentence_embedding, format, out_dir, state)
153+
}
148154
}
149155
}
150156
Commands::Mcp { hostname, port } => {

encoderfile/src/common/mod.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1-
pub mod embedding;
2-
pub mod model_metadata;
3-
pub mod model_type;
4-
pub mod sequence_classification;
5-
pub mod token;
6-
pub mod token_classification;
1+
mod embedding;
2+
mod model_metadata;
3+
mod model_type;
4+
mod sentence_embedding;
5+
mod sequence_classification;
6+
mod token;
7+
mod token_classification;
78

89
pub use embedding::*;
910
pub use model_metadata::*;
1011
pub use model_type::*;
12+
pub use sentence_embedding::*;
1113
pub use sequence_classification::*;
1214
pub use token::*;
1315
pub use token_classification::*;

encoderfile/src/common/model_type.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ pub enum ModelType {
44
Embedding,
55
SequenceClassification,
66
TokenClassification,
7+
SentenceEmbedding,
78
}
89

910
impl From<ModelType> for crate::generated::metadata::ModelType {
@@ -12,6 +13,7 @@ impl From<ModelType> for crate::generated::metadata::ModelType {
1213
ModelType::Embedding => Self::Embedding,
1314
ModelType::SequenceClassification => Self::SequenceClassification,
1415
ModelType::TokenClassification => Self::TokenClassification,
16+
ModelType::SentenceEmbedding => Self::SentenceEmbedding,
1517
}
1618
}
1719
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
use schemars::JsonSchema;
2+
use serde::{Deserialize, Serialize};
3+
use std::collections::HashMap;
4+
use utoipa::ToSchema;
5+
6+
#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema)]
7+
pub struct SentenceEmbeddingRequest {
8+
pub inputs: Vec<String>,
9+
#[serde(default)]
10+
pub metadata: Option<HashMap<String, String>>,
11+
}
12+
13+
impl From<crate::generated::sentence_embedding::SentenceEmbeddingRequest>
14+
for SentenceEmbeddingRequest
15+
{
16+
fn from(val: crate::generated::sentence_embedding::SentenceEmbeddingRequest) -> Self {
17+
Self {
18+
inputs: val.inputs,
19+
metadata: Some(val.metadata),
20+
}
21+
}
22+
}
23+
24+
#[derive(Debug, Serialize, ToSchema, JsonSchema, utoipa::ToResponse)]
25+
pub struct SentenceEmbeddingResponse {
26+
pub results: Vec<SentenceEmbedding>,
27+
pub model_id: String,
28+
#[serde(skip_serializing_if = "Option::is_none")]
29+
pub metadata: Option<HashMap<String, String>>,
30+
}
31+
32+
impl From<SentenceEmbeddingResponse>
33+
for crate::generated::sentence_embedding::SentenceEmbeddingResponse
34+
{
35+
fn from(val: SentenceEmbeddingResponse) -> Self {
36+
Self {
37+
results: val.results.into_iter().map(|i| i.into()).collect(),
38+
model_id: val.model_id,
39+
metadata: val.metadata.unwrap_or_default(),
40+
}
41+
}
42+
}
43+
44+
#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema)]
45+
pub struct SentenceEmbedding {
46+
pub embedding: Vec<f32>,
47+
}
48+
49+
impl From<SentenceEmbedding> for crate::generated::sentence_embedding::SentenceEmbedding {
50+
fn from(val: SentenceEmbedding) -> Self {
51+
Self {
52+
embedding: val.embedding,
53+
}
54+
}
55+
}

encoderfile/src/generated/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ pub mod token_classification {
1010
tonic::include_proto!("encoderfile.token_classification");
1111
}
1212

13+
pub mod sentence_embedding {
14+
tonic::include_proto!("encoderfile.sentence_embedding");
15+
}
16+
1317
pub mod token {
1418
tonic::include_proto!("encoderfile.token");
1519
}

encoderfile/src/inference/embedding.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@ pub fn embedding<'a>(
2424
.expect("Model does not return tensor of shape [n_batch, n_tokens, hidden_dim]")
2525
.into_owned();
2626

27-
if let Some(transform) = state.transform() {
28-
outputs = transform.postprocess(outputs)?;
29-
}
27+
outputs = state.transform().postprocess(outputs)?;
3028

3129
let embeddings = postprocess(outputs, encodings);
3230

0 commit comments

Comments
 (0)