diff --git a/src/base/spec.rs b/src/base/spec.rs index 94cb295ad..e632765bc 100644 --- a/src/base/spec.rs +++ b/src/base/spec.rs @@ -2,7 +2,7 @@ use std::ops::Deref; use serde::{Deserialize, Serialize}; -use super::schema::FieldSchema; +use super::schema::{EnrichedValueType, FieldSchema}; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "kind")] @@ -85,7 +85,8 @@ pub struct FieldMapping { } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct LiteralMapping { +pub struct ConstantMapping { + pub schema: EnrichedValueType, pub value: serde_json::Value, } @@ -103,7 +104,7 @@ pub struct StructMapping { #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "kind")] pub enum ValueMapping { - Literal(LiteralMapping), + Constant(ConstantMapping), Field(FieldMapping), Struct(StructMapping), // TODO: Add support for collections @@ -124,7 +125,7 @@ impl ValueMapping { impl std::fmt::Display for ValueMapping { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - ValueMapping::Literal(v) => write!( + ValueMapping::Constant(v) => write!( f, "{}", serde_json::to_string(&v.value) diff --git a/src/builder/analyzer.rs b/src/builder/analyzer.rs index 8e4508daf..e5db86959 100644 --- a/src/builder/analyzer.rs +++ b/src/builder/analyzer.rs @@ -11,11 +11,7 @@ use crate::setup::{ use crate::utils::fingerprint::Fingerprinter; use crate::{ api_bail, api_error, - base::{ - schema::*, - spec::*, - value::{self, *}, - }, + base::{schema::*, spec::*, value}, ops::{interface::*, registry::*}, utils::immutable::RefList, }; @@ -497,30 +493,9 @@ fn analyze_value_mapping( scopes: RefList<'_, &'_ ExecutionScope<'_>>, ) -> Result<(AnalyzedValueMapping, EnrichedValueType)> { let result = match value_mapping { - ValueMapping::Literal(v) => { - let (value_type, basic_value) = match &v.value { - serde_json::Value::String(s) => { - (BasicValueType::Str, BasicValue::Str(Arc::from(s.as_str()))) - } - serde_json::Value::Number(n) => ( - BasicValueType::Float64, - BasicValue::Float64( - n.as_f64().ok_or_else(|| anyhow!("Invalid number: {}", n))?, - ), - ), - serde_json::Value::Bool(b) => (BasicValueType::Bool, BasicValue::Bool(*b)), - _ => bail!("Unsupported value type: {}", v.value), - }; - ( - AnalyzedValueMapping::Literal { - value: value::Value::Basic(basic_value), - }, - EnrichedValueType { - typ: ValueType::Basic(value_type), - nullable: false, - attrs: Default::default(), - }, - ) + ValueMapping::Constant(v) => { + let value = value::Value::from_json(v.value.clone(), &v.schema.typ)?; + (AnalyzedValueMapping::Constant { value }, v.schema.clone()) } ValueMapping::Field(v) => { diff --git a/src/builder/flow_builder.rs b/src/builder/flow_builder.rs index 48340a848..f0558a03a 100644 --- a/src/builder/flow_builder.rs +++ b/src/builder/flow_builder.rs @@ -158,19 +158,19 @@ impl DataScopeRef { #[pyclass] #[derive(Debug, Clone)] pub struct DataType { - typ: schema::EnrichedValueType, + schema: schema::EnrichedValueType, } impl From for DataType { - fn from(typ: schema::EnrichedValueType) -> Self { - Self { typ } + fn from(schema: schema::EnrichedValueType) -> Self { + Self { schema } } } #[pymethods] impl DataType { pub fn __str__(&self) -> String { - format!("{}", self.typ) + format!("{}", self.schema) } pub fn __repr__(&self) -> String { @@ -201,7 +201,7 @@ impl DataSlice { } pub fn field(&self, field_name: &str) -> PyResult> { - let field_schema = match &self.data_type.typ.typ { + let field_schema = match &self.data_type.schema.typ { schema::ValueType::Struct(struct_type) => { match struct_type.fields.iter().find(|f| f.name == field_name) { Some(field) => field, @@ -232,7 +232,7 @@ impl DataSlice { .map(|f| f.spec.clone()) .ok_or_else(|| PyException::new_err(format!("field {} not found", field_name)))?, - spec::ValueMapping::Literal { .. } => { + spec::ValueMapping::Constant { .. } => { return Err(PyException::new_err( "field access not supported for literal", )) @@ -277,7 +277,7 @@ impl std::fmt::Display for DataSlice { write!( f, "DataSlice({}; {} {}) ", - self.data_type.typ, self.scope, self.value + self.data_type.schema, self.scope, self.value )?; Ok(()) } @@ -420,6 +420,24 @@ impl FlowBuilder { Ok(result) } + pub fn constant<'py>( + &self, + value_type: py::Pythonized, + value: Bound<'py, PyAny>, + ) -> PyResult { + let schema = value_type.into_inner(); + let value = py::value_from_py_object(&schema.typ, &value)?; + let slice = DataSlice { + scope: self.root_data_scope_ref.clone(), + value: Arc::new(spec::ValueMapping::Constant(spec::ConstantMapping { + schema: schema.clone(), + value: serde_json::to_value(value).into_py_result()?, + })), + data_type: schema.into(), + }; + Ok(slice) + } + pub fn add_direct_input( &mut self, name: String, @@ -533,7 +551,7 @@ impl FlowBuilder { .into_iter() .map(|(name, ds)| FieldSchema { name, - value_type: ds.data_type.typ, + value_type: ds.data_type.schema, }) .collect(), ), @@ -600,7 +618,7 @@ impl FlowBuilder { scope: None, field_path: spec::FieldPath(vec![field_name.to_string()]), })), - data_type: DataType { typ: field_type }, + data_type: DataType { schema: field_type }, })) } diff --git a/src/builder/plan.rs b/src/builder/plan.rs index 8dc94a568..2334bf24d 100644 --- a/src/builder/plan.rs +++ b/src/builder/plan.rs @@ -45,7 +45,7 @@ pub struct AnalyzedStructMapping { #[derive(Debug, Clone, Serialize)] #[serde(tag = "kind")] pub enum AnalyzedValueMapping { - Literal { value: value::Value }, + Constant { value: value::Value }, Field(AnalyzedFieldReference), Struct(AnalyzedStructMapping), } diff --git a/src/execution/evaluator.rs b/src/execution/evaluator.rs index f9920b481..9222c38d3 100644 --- a/src/execution/evaluator.rs +++ b/src/execution/evaluator.rs @@ -266,7 +266,7 @@ fn assemble_value( scoped_entries: RefList<'_, &ScopeEntry<'_>>, ) -> value::Value { match value_mapping { - AnalyzedValueMapping::Literal { value } => value.clone(), + AnalyzedValueMapping::Constant { value } => value.clone(), AnalyzedValueMapping::Field(field_ref) => scoped_entries .headn(field_ref.scope_up_level as usize) .unwrap() diff --git a/src/ops/factory_bases.rs b/src/ops/factory_bases.rs index 40e18634d..c8dda8abd 100644 --- a/src/ops/factory_bases.rs +++ b/src/ops/factory_bases.rs @@ -24,8 +24,20 @@ pub struct ResolvedOpArg { pub idx: usize, } -impl ResolvedOpArg { - pub fn expect_type(self, expected_type: &ValueType) -> Result { +pub trait ResolvedOpArgExt: Sized { + type ValueType; + type ValueRef<'a>; + + fn expect_type(self, expected_type: &ValueType) -> Result; + fn value<'a>(&self, args: &'a Vec) -> Result>; + fn take_value(&self, args: &mut Vec) -> Result; +} + +impl ResolvedOpArgExt for ResolvedOpArg { + type ValueType = value::Value; + type ValueRef<'a> = &'a value::Value; + + fn expect_type(self, expected_type: &ValueType) -> Result { if &self.typ.typ != expected_type { api_bail!( "Expected argument `{}` to be of type `{}`, got `{}`", @@ -37,7 +49,7 @@ impl ResolvedOpArg { Ok(self) } - pub fn value<'a>(&self, args: &'a Vec) -> Result<&'a value::Value> { + fn value<'a>(&self, args: &'a Vec) -> Result<&'a value::Value> { if self.idx >= args.len() { api_bail!( "Two few arguments, {} provided, expected at least {} for `{}`", @@ -49,7 +61,7 @@ impl ResolvedOpArg { Ok(&args[self.idx]) } - pub fn take_value(&self, args: &mut Vec) -> Result { + fn take_value(&self, args: &mut Vec) -> Result { if self.idx >= args.len() { api_bail!( "Two few arguments, {} provided, expected at least {} for `{}`", @@ -62,6 +74,23 @@ impl ResolvedOpArg { } } +impl ResolvedOpArgExt for Option { + type ValueType = Option; + type ValueRef<'a> = Option<&'a value::Value>; + + fn expect_type(self, expected_type: &ValueType) -> Result { + self.map(|arg| arg.expect_type(expected_type)).transpose() + } + + fn value<'a>(&self, args: &'a Vec) -> Result> { + self.as_ref().map(|arg| arg.value(args)).transpose() + } + + fn take_value(&self, args: &mut Vec) -> Result> { + self.as_ref().map(|arg| arg.take_value(args)).transpose() + } +} + pub struct OpArgsResolver<'a> { args: &'a [OpArgSchema], num_positional_args: usize,