@@ -90,6 +90,23 @@ pub enum FunctionArgKind {
9090 Field ,
9191}
9292
93+ impl FunctionArgKind {
94+ /// Check if the current argument kind matches the expected one.
95+ pub fn expect (
96+ & self ,
97+ expected_arg_kind : FunctionArgKind ,
98+ ) -> Result < ( ) , FunctionArgKindMismatchError > {
99+ if self == & expected_arg_kind {
100+ Ok ( ( ) )
101+ } else {
102+ Err ( FunctionArgKindMismatchError {
103+ expected : expected_arg_kind,
104+ actual : * self ,
105+ } )
106+ }
107+ }
108+ }
109+
93110/// An error that occurs on a kind mismatch.
94111#[ derive( Debug , PartialEq , Eq , Error ) ]
95112#[ error( "expected argument of kind {expected:?}, but got {actual:?}" ) ]
@@ -167,11 +184,8 @@ pub enum FunctionParam<'a> {
167184}
168185
169186impl From < & FunctionParam < ' _ > > for FunctionArgKind {
170- fn from ( arg : & FunctionParam < ' _ > ) -> Self {
171- match arg {
172- FunctionParam :: Constant ( _) => FunctionArgKind :: Literal ,
173- FunctionParam :: Variable ( _) => FunctionArgKind :: Field ,
174- }
187+ fn from ( param : & FunctionParam < ' _ > ) -> Self {
188+ param. arg_kind ( )
175189 }
176190}
177191
@@ -207,21 +221,11 @@ impl<'a> FunctionParam<'a> {
207221 }
208222 }
209223
210- /// Check if the arg_kind of current paramater matches the expected_arg_kind
211- pub fn expect_arg_kind (
212- & self ,
213- expected_arg_kind : FunctionArgKind ,
214- ) -> Result < ( ) , FunctionParamError > {
215- let kind = self . into ( ) ;
216- if kind == expected_arg_kind {
217- Ok ( ( ) )
218- } else {
219- Err ( FunctionParamError :: KindMismatch (
220- FunctionArgKindMismatchError {
221- expected : expected_arg_kind,
222- actual : kind,
223- } ,
224- ) )
224+ /// Returns the associated argument kind.
225+ pub fn arg_kind ( & self ) -> FunctionArgKind {
226+ match self {
227+ FunctionParam :: Constant ( _) => FunctionArgKind :: Literal ,
228+ FunctionParam :: Variable ( _) => FunctionArgKind :: Field ,
225229 }
226230 }
227231
@@ -434,11 +438,32 @@ impl PartialEq for SimpleFunctionImpl {
434438
435439impl Eq for SimpleFunctionImpl { }
436440
441+ /// Kind of argument the function parameter expects.
442+ #[ derive( Debug , PartialEq , Eq , Clone , Copy ) ]
443+ pub enum SimpleFunctionArgKind {
444+ /// The parameter is expecting a literal value.
445+ Literal ,
446+ /// The parameter is expecting a field / dynamic value.
447+ Field ,
448+ /// The parameter is expecting either a literal or a field / dynamic value.
449+ Both ,
450+ }
451+
452+ impl SimpleFunctionArgKind {
453+ fn expect ( & self , arg_kind : FunctionArgKind ) -> Result < ( ) , FunctionArgKindMismatchError > {
454+ match self {
455+ SimpleFunctionArgKind :: Literal => arg_kind. expect ( FunctionArgKind :: Literal ) ,
456+ SimpleFunctionArgKind :: Field => arg_kind. expect ( FunctionArgKind :: Field ) ,
457+ SimpleFunctionArgKind :: Both => Ok ( ( ) ) ,
458+ }
459+ }
460+ }
461+
437462/// Defines a mandatory function argument.
438463#[ derive( Debug , PartialEq , Eq , Clone ) ]
439464pub struct SimpleFunctionParam {
440465 /// How the argument can be specified when calling a function.
441- pub arg_kind : FunctionArgKind ,
466+ pub arg_kind : SimpleFunctionArgKind ,
442467 /// The type of its associated value.
443468 pub val_type : Type ,
444469}
@@ -447,7 +472,7 @@ pub struct SimpleFunctionParam {
447472#[ derive( Debug , PartialEq , Eq , Clone ) ]
448473pub struct SimpleFunctionOptParam {
449474 /// How the argument can be specified when calling a function.
450- pub arg_kind : FunctionArgKind ,
475+ pub arg_kind : SimpleFunctionArgKind ,
451476 /// The default value if the argument is missing.
452477 pub default_value : LhsValue < ' static > ,
453478}
@@ -476,11 +501,11 @@ impl FunctionDefinition for SimpleFunctionDefinition {
476501 let index = params. len ( ) ;
477502 if index < self . params . len ( ) {
478503 let param = & self . params [ index] ;
479- next_param . expect_arg_kind ( param . arg_kind ) ?;
504+ param . arg_kind . expect ( next_param . arg_kind ( ) ) ?;
480505 next_param. expect_val_type ( once ( ExpectedType :: Type ( param. val_type ) ) ) ?;
481506 } else if index < self . params . len ( ) + self . opt_params . len ( ) {
482507 let opt_param = & self . opt_params [ index - self . params . len ( ) ] ;
483- next_param . expect_arg_kind ( opt_param . arg_kind ) ?;
508+ opt_param . arg_kind . expect ( next_param . arg_kind ( ) ) ?;
484509 next_param
485510 . expect_val_type ( once ( ExpectedType :: Type ( opt_param. default_value . get_type ( ) ) ) ) ?;
486511 } else {
0 commit comments