Skip to content

Commit 0cc1da3

Browse files
authored
Robust support for passing constant to functions. (#123)
1 parent fca729a commit 0cc1da3

File tree

6 files changed

+71
-48
lines changed

6 files changed

+71
-48
lines changed

src/base/spec.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::ops::Deref;
22

33
use serde::{Deserialize, Serialize};
44

5-
use super::schema::FieldSchema;
5+
use super::schema::{EnrichedValueType, FieldSchema};
66

77
#[derive(Debug, Clone, Serialize, Deserialize)]
88
#[serde(tag = "kind")]
@@ -85,7 +85,8 @@ pub struct FieldMapping {
8585
}
8686

8787
#[derive(Debug, Clone, Serialize, Deserialize)]
88-
pub struct LiteralMapping {
88+
pub struct ConstantMapping {
89+
pub schema: EnrichedValueType,
8990
pub value: serde_json::Value,
9091
}
9192

@@ -103,7 +104,7 @@ pub struct StructMapping {
103104
#[derive(Debug, Clone, Serialize, Deserialize)]
104105
#[serde(tag = "kind")]
105106
pub enum ValueMapping {
106-
Literal(LiteralMapping),
107+
Constant(ConstantMapping),
107108
Field(FieldMapping),
108109
Struct(StructMapping),
109110
// TODO: Add support for collections
@@ -124,7 +125,7 @@ impl ValueMapping {
124125
impl std::fmt::Display for ValueMapping {
125126
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126127
match self {
127-
ValueMapping::Literal(v) => write!(
128+
ValueMapping::Constant(v) => write!(
128129
f,
129130
"{}",
130131
serde_json::to_string(&v.value)

src/builder/analyzer.rs

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@ use crate::setup::{
1111
use crate::utils::fingerprint::Fingerprinter;
1212
use crate::{
1313
api_bail, api_error,
14-
base::{
15-
schema::*,
16-
spec::*,
17-
value::{self, *},
18-
},
14+
base::{schema::*, spec::*, value},
1915
ops::{interface::*, registry::*},
2016
utils::immutable::RefList,
2117
};
@@ -497,30 +493,9 @@ fn analyze_value_mapping(
497493
scopes: RefList<'_, &'_ ExecutionScope<'_>>,
498494
) -> Result<(AnalyzedValueMapping, EnrichedValueType)> {
499495
let result = match value_mapping {
500-
ValueMapping::Literal(v) => {
501-
let (value_type, basic_value) = match &v.value {
502-
serde_json::Value::String(s) => {
503-
(BasicValueType::Str, BasicValue::Str(Arc::from(s.as_str())))
504-
}
505-
serde_json::Value::Number(n) => (
506-
BasicValueType::Float64,
507-
BasicValue::Float64(
508-
n.as_f64().ok_or_else(|| anyhow!("Invalid number: {}", n))?,
509-
),
510-
),
511-
serde_json::Value::Bool(b) => (BasicValueType::Bool, BasicValue::Bool(*b)),
512-
_ => bail!("Unsupported value type: {}", v.value),
513-
};
514-
(
515-
AnalyzedValueMapping::Literal {
516-
value: value::Value::Basic(basic_value),
517-
},
518-
EnrichedValueType {
519-
typ: ValueType::Basic(value_type),
520-
nullable: false,
521-
attrs: Default::default(),
522-
},
523-
)
496+
ValueMapping::Constant(v) => {
497+
let value = value::Value::from_json(v.value.clone(), &v.schema.typ)?;
498+
(AnalyzedValueMapping::Constant { value }, v.schema.clone())
524499
}
525500

526501
ValueMapping::Field(v) => {

src/builder/flow_builder.rs

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -158,19 +158,19 @@ impl DataScopeRef {
158158
#[pyclass]
159159
#[derive(Debug, Clone)]
160160
pub struct DataType {
161-
typ: schema::EnrichedValueType,
161+
schema: schema::EnrichedValueType,
162162
}
163163

164164
impl From<schema::EnrichedValueType> for DataType {
165-
fn from(typ: schema::EnrichedValueType) -> Self {
166-
Self { typ }
165+
fn from(schema: schema::EnrichedValueType) -> Self {
166+
Self { schema }
167167
}
168168
}
169169

170170
#[pymethods]
171171
impl DataType {
172172
pub fn __str__(&self) -> String {
173-
format!("{}", self.typ)
173+
format!("{}", self.schema)
174174
}
175175

176176
pub fn __repr__(&self) -> String {
@@ -201,7 +201,7 @@ impl DataSlice {
201201
}
202202

203203
pub fn field(&self, field_name: &str) -> PyResult<Option<DataSlice>> {
204-
let field_schema = match &self.data_type.typ.typ {
204+
let field_schema = match &self.data_type.schema.typ {
205205
schema::ValueType::Struct(struct_type) => {
206206
match struct_type.fields.iter().find(|f| f.name == field_name) {
207207
Some(field) => field,
@@ -232,7 +232,7 @@ impl DataSlice {
232232
.map(|f| f.spec.clone())
233233
.ok_or_else(|| PyException::new_err(format!("field {} not found", field_name)))?,
234234

235-
spec::ValueMapping::Literal { .. } => {
235+
spec::ValueMapping::Constant { .. } => {
236236
return Err(PyException::new_err(
237237
"field access not supported for literal",
238238
))
@@ -277,7 +277,7 @@ impl std::fmt::Display for DataSlice {
277277
write!(
278278
f,
279279
"DataSlice({}; {} {}) ",
280-
self.data_type.typ, self.scope, self.value
280+
self.data_type.schema, self.scope, self.value
281281
)?;
282282
Ok(())
283283
}
@@ -420,6 +420,24 @@ impl FlowBuilder {
420420
Ok(result)
421421
}
422422

423+
pub fn constant<'py>(
424+
&self,
425+
value_type: py::Pythonized<schema::EnrichedValueType>,
426+
value: Bound<'py, PyAny>,
427+
) -> PyResult<DataSlice> {
428+
let schema = value_type.into_inner();
429+
let value = py::value_from_py_object(&schema.typ, &value)?;
430+
let slice = DataSlice {
431+
scope: self.root_data_scope_ref.clone(),
432+
value: Arc::new(spec::ValueMapping::Constant(spec::ConstantMapping {
433+
schema: schema.clone(),
434+
value: serde_json::to_value(value).into_py_result()?,
435+
})),
436+
data_type: schema.into(),
437+
};
438+
Ok(slice)
439+
}
440+
423441
pub fn add_direct_input(
424442
&mut self,
425443
name: String,
@@ -533,7 +551,7 @@ impl FlowBuilder {
533551
.into_iter()
534552
.map(|(name, ds)| FieldSchema {
535553
name,
536-
value_type: ds.data_type.typ,
554+
value_type: ds.data_type.schema,
537555
})
538556
.collect(),
539557
),
@@ -600,7 +618,7 @@ impl FlowBuilder {
600618
scope: None,
601619
field_path: spec::FieldPath(vec![field_name.to_string()]),
602620
})),
603-
data_type: DataType { typ: field_type },
621+
data_type: DataType { schema: field_type },
604622
}))
605623
}
606624

src/builder/plan.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ pub struct AnalyzedStructMapping {
4545
#[derive(Debug, Clone, Serialize)]
4646
#[serde(tag = "kind")]
4747
pub enum AnalyzedValueMapping {
48-
Literal { value: value::Value },
48+
Constant { value: value::Value },
4949
Field(AnalyzedFieldReference),
5050
Struct(AnalyzedStructMapping),
5151
}

src/execution/evaluator.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ fn assemble_value(
266266
scoped_entries: RefList<'_, &ScopeEntry<'_>>,
267267
) -> value::Value {
268268
match value_mapping {
269-
AnalyzedValueMapping::Literal { value } => value.clone(),
269+
AnalyzedValueMapping::Constant { value } => value.clone(),
270270
AnalyzedValueMapping::Field(field_ref) => scoped_entries
271271
.headn(field_ref.scope_up_level as usize)
272272
.unwrap()

src/ops/factory_bases.rs

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,20 @@ pub struct ResolvedOpArg {
2424
pub idx: usize,
2525
}
2626

27-
impl ResolvedOpArg {
28-
pub fn expect_type(self, expected_type: &ValueType) -> Result<Self> {
27+
pub trait ResolvedOpArgExt: Sized {
28+
type ValueType;
29+
type ValueRef<'a>;
30+
31+
fn expect_type(self, expected_type: &ValueType) -> Result<Self>;
32+
fn value<'a>(&self, args: &'a Vec<value::Value>) -> Result<Self::ValueRef<'a>>;
33+
fn take_value(&self, args: &mut Vec<value::Value>) -> Result<Self::ValueType>;
34+
}
35+
36+
impl ResolvedOpArgExt for ResolvedOpArg {
37+
type ValueType = value::Value;
38+
type ValueRef<'a> = &'a value::Value;
39+
40+
fn expect_type(self, expected_type: &ValueType) -> Result<Self> {
2941
if &self.typ.typ != expected_type {
3042
api_bail!(
3143
"Expected argument `{}` to be of type `{}`, got `{}`",
@@ -37,7 +49,7 @@ impl ResolvedOpArg {
3749
Ok(self)
3850
}
3951

40-
pub fn value<'a>(&self, args: &'a Vec<value::Value>) -> Result<&'a value::Value> {
52+
fn value<'a>(&self, args: &'a Vec<value::Value>) -> Result<&'a value::Value> {
4153
if self.idx >= args.len() {
4254
api_bail!(
4355
"Two few arguments, {} provided, expected at least {} for `{}`",
@@ -49,7 +61,7 @@ impl ResolvedOpArg {
4961
Ok(&args[self.idx])
5062
}
5163

52-
pub fn take_value(&self, args: &mut Vec<value::Value>) -> Result<value::Value> {
64+
fn take_value(&self, args: &mut Vec<value::Value>) -> Result<value::Value> {
5365
if self.idx >= args.len() {
5466
api_bail!(
5567
"Two few arguments, {} provided, expected at least {} for `{}`",
@@ -62,6 +74,23 @@ impl ResolvedOpArg {
6274
}
6375
}
6476

77+
impl ResolvedOpArgExt for Option<ResolvedOpArg> {
78+
type ValueType = Option<value::Value>;
79+
type ValueRef<'a> = Option<&'a value::Value>;
80+
81+
fn expect_type(self, expected_type: &ValueType) -> Result<Self> {
82+
self.map(|arg| arg.expect_type(expected_type)).transpose()
83+
}
84+
85+
fn value<'a>(&self, args: &'a Vec<value::Value>) -> Result<Option<&'a value::Value>> {
86+
self.as_ref().map(|arg| arg.value(args)).transpose()
87+
}
88+
89+
fn take_value(&self, args: &mut Vec<value::Value>) -> Result<Option<value::Value>> {
90+
self.as_ref().map(|arg| arg.take_value(args)).transpose()
91+
}
92+
}
93+
6594
pub struct OpArgsResolver<'a> {
6695
args: &'a [OpArgSchema],
6796
num_positional_args: usize,

0 commit comments

Comments
 (0)