Skip to content

Commit de2c50e

Browse files
committed
passing functionSignature instead of name and arguments
1 parent 1360292 commit de2c50e

File tree

11 files changed

+189
-224
lines changed

11 files changed

+189
-224
lines changed

src/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ set(SRCS
4747
Type.cpp
4848
Function.cpp
4949
Extension.cpp
50-
FunctionMapping.h FunctionLookup.cpp FunctionLookup.h)
50+
FunctionMapping.h
51+
FunctionLookup.cpp
52+
FunctionSignature.h)
5153

5254
add_library(substrait-cpp ${SRCS})
5355

src/Extension.cpp

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -264,48 +264,6 @@ TypeVariantPtr Extension::lookupType(const std::string &typeName) const {
264264
return nullptr;
265265
}
266266

267-
FunctionVariantPtr
268-
Extension::lookupScalarFunction(const std::string &name,
269-
const std::vector<TypePtr> &types) const {
270-
auto functionVariantIter = scalarFunctionVariantMap_.find(name);
271-
if (functionVariantIter != scalarFunctionVariantMap_.end()) {
272-
for (const auto &candidateFunctionVariant : functionVariantIter->second) {
273-
if (candidateFunctionVariant->tryMatch(types)) {
274-
return candidateFunctionVariant;
275-
}
276-
}
277-
}
278-
return nullptr;
279-
}
280-
281-
FunctionVariantPtr
282-
Extension::lookupAggregateFunction(const std::string &name,
283-
const std::vector<TypePtr> &types) const {
284-
auto functionVariantIter = aggregateFunctionVariantMap_.find(name);
285-
if (functionVariantIter != aggregateFunctionVariantMap_.end()) {
286-
for (const auto &candidateFunctionVariant : functionVariantIter->second) {
287-
if (candidateFunctionVariant->tryMatch(types)) {
288-
return candidateFunctionVariant;
289-
}
290-
}
291-
}
292-
return nullptr;
293-
}
294-
295-
FunctionVariantPtr
296-
Extension::lookupWindowFunction(const std::string &name,
297-
const std::vector<TypePtr> &types) const {
298-
auto functionVariantIter = windowFunctionVariantMap_.find(name);
299-
if (functionVariantIter != windowFunctionVariantMap_.end()) {
300-
for (const auto &candidateFunctionVariant : functionVariantIter->second) {
301-
if (candidateFunctionVariant->tryMatch(types)) {
302-
return candidateFunctionVariant;
303-
}
304-
}
305-
}
306-
return nullptr;
307-
}
308-
309267
void Extension::addScalarFunctionVariant(
310268
const FunctionVariantPtr &functionVariant) {
311269
const auto &functionVariants =

src/Extension.h

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#pragma once
1616

1717
#include "Function.h"
18+
#include "FunctionSignature.h"
1819
#include "Type.h"
1920

2021
namespace io::substrait {
@@ -57,28 +58,22 @@ class Extension {
5758
/// Add a type variant.
5859
void addTypeVariant(const TypeVariantPtr &functionVariant);
5960

60-
/// Lookup scalar function variant by given name and function types.
61-
/// @return matched function variant.
62-
FunctionVariantPtr
63-
lookupScalarFunction(const std::string &name,
64-
const std::vector<TypePtr> &types) const;
65-
66-
/// Lookup aggregate function variant by given name and function types.
67-
/// @return matched function variant.
68-
FunctionVariantPtr
69-
lookupAggregateFunction(const std::string &name,
70-
const std::vector<TypePtr> &types) const;
71-
72-
/// Lookup window function variant by given name and function types.
73-
/// @return matched function variant.
74-
FunctionVariantPtr
75-
lookupWindowFunction(const std::string &name,
76-
const std::vector<TypePtr> &types) const;
77-
7861
/// Lookup type variant by given type name.
7962
/// @return matched type variant
8063
TypeVariantPtr lookupType(const std::string &typeName) const;
8164

65+
const FunctionVariantMap &scalaFunctionVariantMap() const {
66+
return scalarFunctionVariantMap_;
67+
}
68+
69+
const FunctionVariantMap &windowFunctionVariantMap() const {
70+
return windowFunctionVariantMap_;
71+
}
72+
73+
const FunctionVariantMap &aggregateFunctionVariantMap() const {
74+
return aggregateFunctionVariantMap_;
75+
}
76+
8277
private:
8378
static std::shared_ptr<Extension> loadDefault();
8479

src/Function.cpp

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,14 @@
1818
namespace io::substrait {
1919

2020
std::string
21-
FunctionVariant::signature(
22-
const std::string& name,
23-
const std::vector<FunctionArgumentPtr>& arguments) {
21+
FunctionVariant::signature(const std::string &name,
22+
const std::vector<FunctionArgumentPtr> &arguments) {
2423
std::stringstream ss;
2524
ss << name;
2625
if (!arguments.empty()) {
2726
ss << ":";
2827
for (auto it = arguments.begin(); it != arguments.end(); ++it) {
29-
const auto& typeSign = (*it)->toTypeString();
28+
const auto &typeSign = (*it)->toTypeString();
3029
if (it == arguments.end() - 1) {
3130
ss << typeSign;
3231
} else {
@@ -38,8 +37,8 @@ FunctionVariant::signature(
3837
return ss.str();
3938
}
4039

41-
bool FunctionVariant::tryMatch(
42-
const std::vector<TypePtr>& actualTypes) {
40+
bool FunctionVariant::tryMatch(const FunctionSignature &signature) {
41+
const auto &actualTypes = signature.getArguments();
4342
if (variadic.has_value()) {
4443
// return false if actual types length less than min of variadic
4544
const auto max = variadic->max;
@@ -48,24 +47,21 @@ bool FunctionVariant::tryMatch(
4847
return false;
4948
}
5049

51-
const auto& variadicArgument = arguments[0];
50+
const auto &variadicArgument = arguments[0];
5251
// actual type must same as the variadicArgument
53-
if (const auto& variadicValueArgument =
54-
std::dynamic_pointer_cast<const ValueArgument>(
55-
variadicArgument)) {
56-
for (auto& actualType : actualTypes) {
52+
if (const auto &variadicValueArgument =
53+
std::dynamic_pointer_cast<const ValueArgument>(variadicArgument)) {
54+
for (auto &actualType : actualTypes) {
5755
if (!variadicValueArgument->type->isSameAs(actualType)) {
5856
return false;
5957
}
6058
}
6159
}
62-
return true;
6360
} else {
6461
std::vector<std::shared_ptr<const ValueArgument>> valueArguments;
65-
for (const auto& argument : arguments) {
66-
if (const auto& variadicValueArgument =
67-
std::dynamic_pointer_cast<const ValueArgument>(
68-
argument)) {
62+
for (const auto &argument : arguments) {
63+
if (const auto &variadicValueArgument =
64+
std::dynamic_pointer_cast<const ValueArgument>(argument)) {
6965
valueArguments.emplace_back(variadicValueArgument);
7066
}
7167
}
@@ -76,19 +72,24 @@ bool FunctionVariant::tryMatch(
7672
}
7773

7874
for (auto i = 0; i < actualTypes.size(); i++) {
79-
const auto& valueArgument = valueArguments[i];
75+
const auto &valueArgument = valueArguments[i];
8076
if (!valueArgument->type->isSameAs(actualTypes[i])) {
8177
return false;
8278
}
8379
}
80+
}
81+
const auto &sigReturnType = signature.getReturnType();
82+
if (this->returnType && sigReturnType) {
83+
return returnType->isSameAs(sigReturnType);
84+
} else {
8485
return true;
8586
}
8687
}
8788

88-
bool AggregateFunctionVariant::tryMatch(
89-
const std::vector<TypePtr>& actualTypes) {
90-
bool matched = FunctionVariant::tryMatch(actualTypes);
89+
bool AggregateFunctionVariant::tryMatch(const FunctionSignature &signature) {
90+
bool matched = FunctionVariant::tryMatch(signature);
9191
if (!matched && intermediate) {
92+
const auto &actualTypes = signature.getArguments();
9293
if (actualTypes.size() == 1) {
9394
return intermediate->isSameAs(actualTypes[0]);
9495
}

src/Function.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#pragma once
1616

17+
#include "FunctionSignature.h"
1718
#include "Type.h"
1819
#include <iostream>
1920

@@ -71,7 +72,7 @@ struct FunctionVariant {
7172
std::optional<FunctionVariadic> variadic;
7273

7374
/// Test if the actual types matched with this function variant.
74-
virtual bool tryMatch(const std::vector<TypePtr> &actualTypes);
75+
virtual bool tryMatch(const FunctionSignature& signature);
7576

7677
/// Create function signature by given function name and arguments.
7778
static std::string
@@ -89,7 +90,7 @@ struct ScalarFunctionVariant : public FunctionVariant {};
8990
struct AggregateFunctionVariant : public FunctionVariant {
9091
TypePtr intermediate;
9192

92-
bool tryMatch(const std::vector<TypePtr> &actualTypes) override;
93+
bool tryMatch(const FunctionSignature& signature) override;
9394
};
9495

9596
} // namespace io::substrait

src/FunctionLookup.cpp

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,35 +17,24 @@
1717
namespace io::substrait {
1818

1919
FunctionVariantPtr
20-
FunctionLookup::lookupScalarFunction(const std::string &functionName,
21-
const std::vector<TypePtr> &types) const {
22-
const auto &functionMappings = functionMapping_->scalaMapping();
23-
const auto &substraitFunctionName =
24-
functionMappings.find(functionName) != functionMappings.end()
25-
? functionMappings.at(functionName)
26-
: functionName;
27-
return extension_->lookupScalarFunction(substraitFunctionName, types);
28-
}
20+
FunctionLookup::lookupFunction(const FunctionSignature &signature) const {
21+
const auto &functionMappings = getFunctionMap();
2922

30-
FunctionVariantPtr FunctionLookup::lookupAggregateFunction(
31-
const std::string &functionName, const std::vector<TypePtr> &types) const {
32-
const auto &functionMappings = functionMapping_->aggregateMapping();
3323
const auto &substraitFunctionName =
34-
functionMappings.find(functionName) != functionMappings.end()
35-
? functionMappings.at(functionName)
36-
: functionName;
37-
return extension_->lookupAggregateFunction(substraitFunctionName, types);
38-
}
24+
functionMappings.find(signature.getName()) != functionMappings.end()
25+
? functionMappings.at(signature.getName())
26+
: signature.getName();
3927

40-
FunctionVariantPtr
41-
FunctionLookup::lookupWindowFunction(const std::string &functionName,
42-
const std::vector<TypePtr> &types) const {
43-
const auto &functionMappings = functionMapping_->windowMapping();
44-
const auto &substraitFunctionName =
45-
functionMappings.find(functionName) != functionMappings.end()
46-
? functionMappings.at(functionName)
47-
: functionName;
48-
return extension_->lookupWindowFunction(substraitFunctionName, types);
28+
const auto &functionVariants = getFunctionVariants();
29+
auto functionVariantIter = functionVariants.find(substraitFunctionName);
30+
if (functionVariantIter != functionVariants.end()) {
31+
for (const auto &candidateFunctionVariant : functionVariantIter->second) {
32+
if (candidateFunctionVariant->tryMatch(signature)) {
33+
return candidateFunctionVariant;
34+
}
35+
}
36+
}
37+
return nullptr;
4938
}
5039

5140
} // namespace io::substrait

src/FunctionLookup.h

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "Extension.h"
1818
#include "FunctionMapping.h"
19+
#include "FunctionSignature.h"
1920

2021
namespace io::substrait {
2122

@@ -25,23 +26,69 @@ class FunctionLookup {
2526
const FunctionMappingPtr &functionMapping)
2627
: extension_(extension), functionMapping_(functionMapping) {}
2728

28-
FunctionVariantPtr
29-
lookupScalarFunction(const std::string &functionName,
30-
const std::vector<TypePtr> &types) const;
29+
virtual FunctionVariantPtr
30+
lookupFunction(const FunctionSignature &signature) const;
3131

32-
FunctionVariantPtr
33-
lookupAggregateFunction(const std::string &functionName,
34-
const std::vector<TypePtr> &types) const;
32+
virtual ~FunctionLookup() {}
3533

36-
FunctionVariantPtr
37-
lookupWindowFunction(const std::string &functionName,
38-
const std::vector<TypePtr> &types) const;
34+
protected:
35+
virtual const FunctionMap getFunctionMap() const = 0;
36+
37+
virtual const FunctionVariantMap &getFunctionVariants() const = 0;
38+
39+
const FunctionMappingPtr functionMapping_;
3940

40-
private:
4141
ExtensionPtr extension_;
42-
FunctionMappingPtr functionMapping_;
4342
};
4443

4544
using FunctionLookupPtr = std::shared_ptr<const FunctionLookup>;
4645

46+
class ScalarFunctionLookup : public FunctionLookup {
47+
public:
48+
ScalarFunctionLookup(const ExtensionPtr &extension,
49+
const FunctionMappingPtr &functionMapping)
50+
: FunctionLookup(extension, functionMapping) {}
51+
52+
protected:
53+
const FunctionMap getFunctionMap() const override {
54+
return functionMapping_->scalaMapping();
55+
}
56+
57+
const FunctionVariantMap &getFunctionVariants() const override {
58+
return extension_->scalaFunctionVariantMap();
59+
}
60+
};
61+
62+
class AggregateFunctionLookup : public FunctionLookup {
63+
public:
64+
AggregateFunctionLookup(const ExtensionPtr &extension,
65+
const FunctionMappingPtr &functionMapping)
66+
: FunctionLookup(extension, functionMapping) {}
67+
68+
protected:
69+
const FunctionMap getFunctionMap() const override {
70+
return functionMapping_->aggregateMapping();
71+
}
72+
73+
const FunctionVariantMap &getFunctionVariants() const override {
74+
return extension_->aggregateFunctionVariantMap();
75+
}
76+
};
77+
78+
class WindowFunctionLookup : public FunctionLookup {
79+
public:
80+
WindowFunctionLookup(const ExtensionPtr &extension,
81+
const FunctionMappingPtr &functionMapping)
82+
: FunctionLookup(extension, functionMapping) {}
83+
84+
protected:
85+
const FunctionMap getFunctionMap() const override {
86+
return functionMapping_->windowMapping();
87+
}
88+
89+
const FunctionVariantMap &getFunctionVariants() const override {
90+
return extension_->windowFunctionVariantMap();
91+
}
92+
};
93+
4794
} // namespace io::substrait

0 commit comments

Comments
 (0)