Skip to content

Commit 4ea302f

Browse files
committed
feat(postgres): allow to put SQL type in target states
1 parent 844739e commit 4ea302f

File tree

1 file changed

+63
-22
lines changed

1 file changed

+63
-22
lines changed

src/ops/targets/postgres.rs

Lines changed: 63 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use sqlx::postgres::types::PgRange;
1515
use std::ops::Bound;
1616

1717
#[derive(Debug, Deserialize)]
18-
pub struct Spec {
18+
struct Spec {
1919
database: Option<spec::AuthEntryReference<DatabaseConnectionSpec>>,
2020
table_name: Option<String>,
2121
schema: Option<String>,
@@ -130,7 +130,7 @@ fn bind_value_field<'arg>(
130130
Ok(())
131131
}
132132

133-
pub struct ExportContext {
133+
struct ExportContext {
134134
db_ref: Option<spec::AuthEntryReference<DatabaseConnectionSpec>>,
135135
db_pool: PgPool,
136136
key_fields_schema: Box<[FieldSchema]>,
@@ -254,7 +254,7 @@ impl ExportContext {
254254
struct TargetFactory;
255255

256256
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
257-
pub struct TableId {
257+
struct TableId {
258258
#[serde(skip_serializing_if = "Option::is_none")]
259259
database: Option<spec::AuthEntryReference<DatabaseConnectionSpec>>,
260260
#[serde(skip_serializing_if = "Option::is_none")]
@@ -277,9 +277,44 @@ impl std::fmt::Display for TableId {
277277
}
278278

279279
#[derive(Debug, Clone, Serialize, Deserialize)]
280-
pub struct SetupState {
280+
#[serde(untagged)]
281+
enum ColumnType {
282+
ValueType(ValueType),
283+
PostgresType(String),
284+
}
285+
286+
impl ColumnType {
287+
fn uses_pgvector(&self) -> bool {
288+
match self {
289+
ColumnType::ValueType(ValueType::Basic(BasicValueType::Vector(vec_schema))) => {
290+
convertible_to_pgvector(vec_schema)
291+
}
292+
ColumnType::PostgresType(pg_type) => {
293+
pg_type.starts_with("vector(") || pg_type.starts_with("halfvec(")
294+
}
295+
_ => false,
296+
}
297+
}
298+
299+
fn to_column_type_sql<'a>(&'a self) -> Cow<'a, str> {
300+
match self {
301+
ColumnType::ValueType(v) => Cow::Owned(to_column_type_sql(v)),
302+
ColumnType::PostgresType(pg_type) => Cow::Borrowed(pg_type),
303+
}
304+
}
305+
}
306+
307+
impl PartialEq for ColumnType {
308+
fn eq(&self, other: &Self) -> bool {
309+
self.to_column_type_sql() == other.to_column_type_sql()
310+
}
311+
}
312+
impl Eq for ColumnType {}
313+
314+
#[derive(Debug, Clone, Serialize, Deserialize)]
315+
struct SetupState {
281316
#[serde(flatten)]
282-
columns: TableColumnsSchema<ValueType>,
317+
columns: TableColumnsSchema<ColumnType>,
283318

284319
vector_indexes: BTreeMap<String, VectorIndexDef>,
285320
}
@@ -295,11 +330,21 @@ impl SetupState {
295330
columns: TableColumnsSchema {
296331
key_columns: key_fields_schema
297332
.iter()
298-
.map(|f| (f.name.clone(), f.value_type.typ.without_attrs()))
333+
.map(|f| {
334+
(
335+
f.name.clone(),
336+
ColumnType::ValueType(f.value_type.typ.without_attrs()),
337+
)
338+
})
299339
.collect(),
300340
value_columns: value_fields_schema
301341
.iter()
302-
.map(|f| (f.name.clone(), f.value_type.typ.without_attrs()))
342+
.map(|f| {
343+
(
344+
f.name.clone(),
345+
ColumnType::ValueType(f.value_type.typ.without_attrs()),
346+
)
347+
})
303348
.collect(),
304349
},
305350
vector_indexes: index_options
@@ -314,12 +359,7 @@ impl SetupState {
314359
self.columns
315360
.value_columns
316361
.iter()
317-
.any(|(_, value)| match &value {
318-
ValueType::Basic(BasicValueType::Vector(vec_schema)) => {
319-
convertible_to_pgvector(vec_schema)
320-
}
321-
_ => false,
322-
})
362+
.any(|(_, t)| t.uses_pgvector())
323363
}
324364
}
325365

@@ -367,27 +407,27 @@ impl<'a> From<&'a SetupState> for Cow<'a, TableColumnsSchema<String>> {
367407
.columns
368408
.key_columns
369409
.iter()
370-
.map(|(k, v)| (k.clone(), to_column_type_sql(v)))
410+
.map(|(k, v)| (k.clone(), v.to_column_type_sql().into_owned()))
371411
.collect(),
372412
value_columns: val
373413
.columns
374414
.value_columns
375415
.iter()
376-
.map(|(k, v)| (k.clone(), to_column_type_sql(v)))
416+
.map(|(k, v)| (k.clone(), v.to_column_type_sql().into_owned()))
377417
.collect(),
378418
})
379419
}
380420
}
381421

382422
#[derive(Debug)]
383-
pub struct TableSetupAction {
423+
struct TableSetupAction {
384424
table_action: TableMainSetupAction<String>,
385425
indexes_to_delete: IndexSet<String>,
386426
indexes_to_create: IndexMap<String, VectorIndexDef>,
387427
}
388428

389429
#[derive(Debug)]
390-
pub struct SetupChange {
430+
struct SetupChange {
391431
create_pgvector_extension: bool,
392432
actions: TableSetupAction,
393433
vector_as_jsonb_columns: Vec<(String, ValueType)>,
@@ -402,7 +442,8 @@ impl SetupChange {
402442
.iter()
403443
.flat_map(|s| {
404444
s.columns.value_columns.iter().filter_map(|(name, schema)| {
405-
if let ValueType::Basic(BasicValueType::Vector(vec_schema)) = schema
445+
if let ColumnType::ValueType(value_type) = schema
446+
&& let ValueType::Basic(BasicValueType::Vector(vec_schema)) = value_type
406447
&& !convertible_to_pgvector(vec_schema)
407448
{
408449
let is_touched = match &table_action.table_upsertion {
@@ -415,7 +456,7 @@ impl SetupChange {
415456
None => false,
416457
};
417458
if is_touched {
418-
Some((name.clone(), schema.clone()))
459+
Some((name.clone(), value_type.clone()))
419460
} else {
420461
None
421462
}
@@ -791,19 +832,19 @@ impl TargetFactoryBase for TargetFactory {
791832
////////////////////////////////////////////////////////////
792833

793834
#[derive(Debug, Clone, Serialize, Deserialize)]
794-
pub struct SqlCommandSpec {
835+
struct SqlCommandSpec {
795836
name: String,
796837
setup_sql: String,
797838
teardown_sql: Option<String>,
798839
}
799840

800841
#[derive(Debug, Clone, Serialize, Deserialize)]
801-
pub struct SqlCommandState {
842+
struct SqlCommandState {
802843
setup_sql: String,
803844
teardown_sql: Option<String>,
804845
}
805846

806-
pub struct SqlCommandSetupChange {
847+
struct SqlCommandSetupChange {
807848
db_pool: PgPool,
808849
setup_sql_to_run: Option<String>,
809850
teardown_sql_to_run: IndexSet<String>,

0 commit comments

Comments
 (0)