Skip to content

Commit a98d4f1

Browse files
committed
refactor(qdrant): prepare vector info ahead of time
1 parent 566c150 commit a98d4f1

File tree

1 file changed

+59
-40
lines changed

1 file changed

+59
-40
lines changed

src/ops/storages/qdrant.rs

Lines changed: 59 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -22,39 +22,30 @@ struct Spec {
2222
api_key: Option<String>,
2323
}
2424

25+
/// Per-field export metadata, prepared once from the value field schema so the
/// per-row export path does not need to inspect the schema again.
#[derive(Debug, Clone, PartialEq, Eq)]
struct FieldInfo {
    /// Name of the value field; used as the payload key or the named-vector key.
    field_name: String,
    /// `true` when the field's vector values are sent as a Qdrant named vector
    /// instead of being serialized into the payload.
    is_qdrant_vector: bool,
}
29+
2530
struct ExportContext {
2631
client: Qdrant,
2732
collection_name: String,
28-
value_fields_schema: Vec<FieldSchema>,
29-
all_fields: Vec<FieldSchema>,
33+
fields_info: Vec<FieldInfo>,
3034
}
3135

3236
impl ExportContext {
3337
fn new(
3438
url: String,
3539
collection_name: String,
3640
api_key: Option<String>,
37-
key_fields_schema: Vec<FieldSchema>,
38-
value_fields_schema: Vec<FieldSchema>,
41+
fields_info: Vec<FieldInfo>,
3942
) -> Result<Self> {
40-
let all_fields = key_fields_schema
41-
.iter()
42-
.chain(value_fields_schema.iter())
43-
.cloned()
44-
.collect::<Vec<_>>();
45-
46-
// Hotfix to resolve
47-
// `no process-level CryptoProvider available -- call CryptoProvider::install_default() before this point`
48-
// when using HTTPS URLs.
49-
let _ = rustls::crypto::ring::default_provider().install_default();
50-
5143
Ok(Self {
5244
client: Qdrant::from_url(&url)
5345
.api_key(api_key)
5446
.skip_compatibility_check()
5547
.build()?,
56-
value_fields_schema,
57-
all_fields,
48+
fields_info,
5849
collection_name,
5950
})
6051
}
@@ -63,8 +54,7 @@ impl ExportContext {
6354
let mut points: Vec<PointStruct> = Vec::with_capacity(mutation.upserts.len());
6455
for upsert in mutation.upserts.iter() {
6556
let point_id = key_to_point_id(&upsert.key)?;
66-
let (payload, vectors) =
67-
values_to_payload(&upsert.value.fields, &self.value_fields_schema)?;
57+
let (payload, vectors) = values_to_payload(&upsert.value.fields, &self.fields_info)?;
6858

6959
points.push(PointStruct::new(point_id, vectors, payload));
7060
}
@@ -105,15 +95,42 @@ fn key_to_point_id(key_value: &KeyValue) -> Result<PointId> {
10595
Ok(point_id)
10696
}
10797

98+
/// Returns `true` when `typ` is a basic vector type whose element type is
/// `Float32`, `Float64`, or `Int64` (the element kinds `encode_vector` can
/// convert to `f32`). Any other type, including vectors of other element
/// types, yields `false`.
fn is_supported_vector_type(typ: &schema::ValueType) -> bool {
    match typ {
        schema::ValueType::Basic(schema::BasicValueType::Vector(vector_schema)) => matches!(
            &*vector_schema.element_type,
            schema::BasicValueType::Float32
                | schema::BasicValueType::Float64
                | schema::BasicValueType::Int64
        ),
        _ => false,
    }
}
111+
112+
fn encode_vector(v: &[BasicValue]) -> Result<Vec<f32>> {
113+
v.iter()
114+
.map(|elem| {
115+
Ok(match elem {
116+
BasicValue::Float32(f) => *f,
117+
BasicValue::Float64(f) => *f as f32,
118+
BasicValue::Int64(i) => *i as f32,
119+
_ => bail!("Unsupported vector type: {:?}", elem.kind()),
120+
})
121+
})
122+
.collect::<Result<Vec<_>>>()
123+
}
124+
108125
fn values_to_payload(
109126
value_fields: &[Value],
110-
schema: &[FieldSchema],
127+
fields_info: &[FieldInfo],
111128
) -> Result<(HashMap<String, QdrantValue>, NamedVectors)> {
112129
let mut payload = HashMap::with_capacity(value_fields.len());
113130
let mut vectors = NamedVectors::default();
114131

115-
for (value, field_schema) in value_fields.iter().zip(schema.iter()) {
116-
let field_name = &field_schema.name;
132+
for (value, field_info) in value_fields.iter().zip(fields_info.iter()) {
133+
let field_name = &field_info.field_name;
117134

118135
match value {
119136
Value::Basic(basic_value) => {
@@ -133,9 +150,12 @@ fn values_to_payload(
133150
BasicValue::TimeDelta(v) => v.to_string().into(),
134151
BasicValue::Json(v) => (**v).clone(),
135152
BasicValue::Vector(v) => {
136-
let vector = convert_to_vector(v.to_vec());
137-
vectors = vectors.add_vector(field_name, vector);
138-
continue;
153+
if field_info.is_qdrant_vector {
154+
let vector = encode_vector(v.as_ref())?;
155+
vectors = vectors.add_vector(field_name, vector);
156+
continue;
157+
}
158+
serde_json::to_value(v)?
139159
}
140160
};
141161
payload.insert(field_name.clone(), json_value.into());
@@ -149,18 +169,6 @@ fn values_to_payload(
149169

150170
Ok((payload, vectors))
151171
}
152-
153-
fn convert_to_vector(v: Vec<BasicValue>) -> Vec<f32> {
154-
v.iter()
155-
.filter_map(|elem| match elem {
156-
BasicValue::Float32(f) => Some(*f),
157-
BasicValue::Float64(f) => Some(*f as f32),
158-
BasicValue::Int64(i) => Some(*i as f32),
159-
_ => None,
160-
})
161-
.collect()
162-
}
163-
164172
/// Factory type for the Qdrant storage target; carries no state of its own
/// (see its `StorageFactoryBase` implementation).
#[derive(Default)]
struct Factory {}
166174

@@ -198,9 +206,21 @@ impl StorageFactoryBase for Factory {
198206
Vec<TypedExportDataCollectionBuildOutput<Self>>,
199207
Vec<(String, ())>,
200208
)> {
209+
// Hotfix to resolve
210+
// `no process-level CryptoProvider available -- call CryptoProvider::install_default() before this point`
211+
// when using HTTPS URLs.
212+
let _ = rustls::crypto::ring::default_provider().install_default();
213+
201214
let data_coll_output = data_collections
202215
.into_iter()
203216
.map(|d| {
217+
let mut fields_info = Vec::<FieldInfo>::new();
218+
for field in d.value_fields_schema.iter() {
219+
fields_info.push(FieldInfo {
220+
field_name: field.name.clone(),
221+
is_qdrant_vector: is_supported_vector_type(&field.value_type.typ),
222+
});
223+
}
204224
if d.key_fields_schema.len() != 1 {
205225
api_bail!(
206226
"Expected one primary key field for the point ID. Got {}.",
@@ -212,10 +232,9 @@ impl StorageFactoryBase for Factory {
212232

213233
let export_context = Arc::new(ExportContext::new(
214234
d.spec.grpc_url,
215-
d.spec.collection_name.clone(),
235+
collection_name.clone(),
216236
d.spec.api_key,
217-
d.key_fields_schema,
218-
d.value_fields_schema,
237+
fields_info,
219238
)?);
220239
let executors = async move {
221240
Ok(TypedExportTargetExecutors {

0 commit comments

Comments
 (0)