1818namespace io ::substrait {
1919
2020std::string
21- FunctionVariant::signature (
22- const std::string& name,
23- const std::vector<FunctionArgumentPtr>& arguments) {
21+ FunctionVariant::signature (const std::string &name,
22+ const std::vector<FunctionArgumentPtr> &arguments) {
2423 std::stringstream ss;
2524 ss << name;
2625 if (!arguments.empty ()) {
2726 ss << " :" ;
2827 for (auto it = arguments.begin (); it != arguments.end (); ++it) {
29- const auto & typeSign = (*it)->toTypeString ();
28+ const auto & typeSign = (*it)->toTypeString ();
3029 if (it == arguments.end () - 1 ) {
3130 ss << typeSign;
3231 } else {
@@ -38,8 +37,8 @@ FunctionVariant::signature(
3837 return ss.str ();
3938}
4039
41- bool FunctionVariant::tryMatch (
42- const std::vector<TypePtr>& actualTypes) {
40+ bool FunctionVariant::tryMatch (const FunctionSignature &signature) {
41+ const auto & actualTypes = signature. getArguments ();
4342 if (variadic.has_value ()) {
4443 // return false if actual types length less than min of variadic
4544 const auto max = variadic->max ;
@@ -48,24 +47,21 @@ bool FunctionVariant::tryMatch(
4847 return false ;
4948 }
5049
51- const auto & variadicArgument = arguments[0 ];
50+ const auto & variadicArgument = arguments[0 ];
5251 // actual type must same as the variadicArgument
53- if (const auto & variadicValueArgument =
54- std::dynamic_pointer_cast<const ValueArgument>(
55- variadicArgument)) {
56- for (auto & actualType : actualTypes) {
52+ if (const auto &variadicValueArgument =
53+ std::dynamic_pointer_cast<const ValueArgument>(variadicArgument)) {
54+ for (auto &actualType : actualTypes) {
5755 if (!variadicValueArgument->type ->isSameAs (actualType)) {
5856 return false ;
5957 }
6058 }
6159 }
62- return true ;
6360 } else {
6461 std::vector<std::shared_ptr<const ValueArgument>> valueArguments;
65- for (const auto & argument : arguments) {
66- if (const auto & variadicValueArgument =
67- std::dynamic_pointer_cast<const ValueArgument>(
68- argument)) {
62+ for (const auto &argument : arguments) {
63+ if (const auto &variadicValueArgument =
64+ std::dynamic_pointer_cast<const ValueArgument>(argument)) {
6965 valueArguments.emplace_back (variadicValueArgument);
7066 }
7167 }
@@ -76,19 +72,24 @@ bool FunctionVariant::tryMatch(
7672 }
7773
7874 for (auto i = 0 ; i < actualTypes.size (); i++) {
79- const auto & valueArgument = valueArguments[i];
75+ const auto & valueArgument = valueArguments[i];
8076 if (!valueArgument->type ->isSameAs (actualTypes[i])) {
8177 return false ;
8278 }
8379 }
80+ }
81+ const auto &sigReturnType = signature.getReturnType ();
82+ if (this ->returnType && sigReturnType) {
83+ return returnType->isSameAs (sigReturnType);
84+ } else {
8485 return true ;
8586 }
8687}
8788
88- bool AggregateFunctionVariant::tryMatch (
89- const std::vector<TypePtr>& actualTypes) {
90- bool matched = FunctionVariant::tryMatch (actualTypes);
89+ bool AggregateFunctionVariant::tryMatch (const FunctionSignature &signature) {
90+ bool matched = FunctionVariant::tryMatch (signature);
9191 if (!matched && intermediate) {
92+ const auto &actualTypes = signature.getArguments ();
9293 if (actualTypes.size () == 1 ) {
9394 return intermediate->isSameAs (actualTypes[0 ]);
9495 }
0 commit comments