Skip to content

Commit 2a73c61

Browse files
committed
fix: validate output vector dimension
1 parent 3223c84 commit 2a73c61

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

src/ops/functions/embed_text.rs

Lines changed: 26 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ struct Spec {
1818
/// Arguments for the text-embedding function: the factory returns this struct
/// together with the output schema, and the executor reads it through
/// `self.args` when it runs.
struct Args {
1919
/// Embedding client; the executor awaits `client.embed_text(req)` to obtain
/// the embedding vector.
client: Box<dyn LlmEmbeddingClient>,
2020
/// Resolved `text` argument — presumably the input text to embed; its use
/// is outside the visible hunks (TODO confirm).
text: ResolvedOpArg,
21+
/// Vector length the executor requires of every returned embedding. The
/// factory sets it from the same `output_dimension` it writes into the
/// output schema, and the executor returns an error when the embedding
/// that comes back has a different length.
expected_output_dimension: usize,
2122
}
2223

2324
struct Executor {
@@ -48,6 +49,23 @@ impl SimpleFunctionExecutor for Executor {
4849
.map(|s| Cow::Borrowed(s.as_str())),
4950
};
5051
let embedding = self.args.client.embed_text(req).await?;
52+
if embedding.embedding.len() != self.args.expected_output_dimension {
53+
if self.spec.output_dimension.is_some() {
54+
api_bail!(
55+
"Expected output dimension {expected} but got {actual} from the embedding API. \
56+
Consider setting `output_dimension` to {actual} or leave it unset to use the default.",
57+
expected = self.args.expected_output_dimension,
58+
actual = embedding.embedding.len()
59+
);
60+
} else {
61+
bail!(
62+
"Expected output dimension {expected} but got {actual} from the embedding API. \
63+
Consider setting `output_dimension` to {actual} as a workaround.",
64+
expected = self.args.expected_output_dimension,
65+
actual = embedding.embedding.len()
66+
)
67+
}
68+
}
5169
Ok(embedding.embedding.into())
5270
}
5371
}
@@ -87,7 +105,14 @@ impl SimpleFunctionFactoryBase for Factory {
87105
dimension: Some(output_dimension as usize),
88106
element_type: Box::new(BasicValueType::Float32),
89107
}));
90-
Ok((Args { client, text }, output_schema))
108+
Ok((
109+
Args {
110+
client,
111+
text,
112+
expected_output_dimension: output_dimension as usize,
113+
},
114+
output_schema,
115+
))
91116
}
92117

93118
async fn build_executor(

0 commit comments

Comments
 (0)