Skip to content

Commit fc829be

Browse files
committed
passing functionSignature instead of name and arguments
1 parent de2c50e commit fc829be

File tree

4 files changed

+47
-62
lines changed

4 files changed

+47
-62
lines changed

src/Function.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ FunctionVariant::signature(const std::string &name,
3838
}
3939

4040
bool FunctionVariant::tryMatch(const FunctionSignature &signature) {
41-
const auto &actualTypes = signature.getArguments();
41+
const auto &actualTypes = signature.arguments;
4242
if (variadic.has_value()) {
4343
// return false if actual types length less than min of variadic
4444
const auto max = variadic->max;
@@ -78,7 +78,7 @@ bool FunctionVariant::tryMatch(const FunctionSignature &signature) {
7878
}
7979
}
8080
}
81-
const auto &sigReturnType = signature.getReturnType();
81+
const auto &sigReturnType = signature.returnType;
8282
if (this->returnType && sigReturnType) {
8383
return returnType->isSameAs(sigReturnType);
8484
} else {
@@ -89,7 +89,7 @@ bool FunctionVariant::tryMatch(const FunctionSignature &signature) {
8989
bool AggregateFunctionVariant::tryMatch(const FunctionSignature &signature) {
9090
bool matched = FunctionVariant::tryMatch(signature);
9191
if (!matched && intermediate) {
92-
const auto &actualTypes = signature.getArguments();
92+
const auto &actualTypes = signature.arguments;
9393
if (actualTypes.size() == 1) {
9494
return intermediate->isSameAs(actualTypes[0]);
9595
}

src/FunctionLookup.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ FunctionLookup::lookupFunction(const FunctionSignature &signature) const {
2121
const auto &functionMappings = getFunctionMap();
2222

2323
const auto &substraitFunctionName =
24-
functionMappings.find(signature.getName()) != functionMappings.end()
25-
? functionMappings.at(signature.getName())
26-
: signature.getName();
24+
functionMappings.find(signature.name) != functionMappings.end()
25+
? functionMappings.at(signature.name)
26+
: signature.name;
2727

2828
const auto &functionVariants = getFunctionVariants();
2929
auto functionVariantIter = functionVariants.find(substraitFunctionName);

src/FunctionSignature.h

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,25 +24,10 @@
2424

2525
namespace io::substrait {
2626

27-
class FunctionSignature {
28-
public:
29-
/// construct the substrait function signature with function name, return type
30-
/// and arguments.
31-
FunctionSignature(const std::string &name,
32-
const std::vector<TypePtr> &arguments,
33-
const TypePtr &returnType)
34-
: name_(name), arguments_(arguments), returnType_(returnType) {}
35-
36-
const std::string getName() const { return name_; }
37-
38-
const std::vector<TypePtr> getArguments() const { return arguments_; }
39-
40-
const TypePtr getReturnType() const { return returnType_; }
41-
42-
private:
43-
const std::string name_;
44-
const std::vector<TypePtr> arguments_;
45-
const TypePtr returnType_;
27+
struct FunctionSignature {
28+
std::string name;
29+
std::vector<TypePtr> arguments;
30+
TypePtr returnType;
4631
};
4732

4833
} // namespace io::substrait

src/tests/FunctionLookupTest.cpp

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -46,24 +46,20 @@ class SubstraitFunctionLookupTest : public ::testing::Test {
4646
std::make_shared<AggregateFunctionLookup>(extension_, mappings_);
4747
}
4848

49-
void testScalarFunctionLookup(const std::string &name,
50-
const std::vector<TypePtr> &arguments,
51-
const TypePtr &returnType,
49+
void testScalarFunctionLookup(const FunctionSignature &inputSignature,
5250
const std::string &outputSignature) {
53-
const auto &functionVariant = scalarFunctionLookup_->lookupFunction(
54-
FunctionSignature(name, arguments, returnType));
51+
const auto &functionVariant =
52+
scalarFunctionLookup_->lookupFunction(inputSignature);
5553

5654
ASSERT_TRUE(functionVariant != nullptr);
5755
ASSERT_EQ(functionVariant->signature(), outputSignature);
5856
}
5957

60-
void testAggregateFunctionLookup(const std::string &name,
61-
const std::vector<TypePtr> &arguments,
62-
const TypePtr &returnType,
58+
void testAggregateFunctionLookup(const FunctionSignature &inputSignature,
6359
const std::string &outputSignature) {
6460

65-
const auto &functionVariant = aggregateFunctionLookup_->lookupFunction(
66-
FunctionSignature(name, arguments, returnType));
61+
const auto &functionVariant =
62+
aggregateFunctionLookup_->lookupFunction(inputSignature);
6763

6864
ASSERT_TRUE(functionVariant != nullptr);
6965
ASSERT_EQ(functionVariant->signature(), outputSignature);
@@ -75,51 +71,55 @@ class SubstraitFunctionLookupTest : public ::testing::Test {
7571
};
7672

7773
TEST_F(SubstraitFunctionLookupTest, compare_function) {
78-
testScalarFunctionLookup("lt", {kI8(), kI8()}, kBool(), "lt:any1_any1");
74+
testScalarFunctionLookup({"lt", {kI8(), kI8()}, kBool()}, "lt:any1_any1");
7975

80-
testScalarFunctionLookup("lt", {kI16(), kI16()}, kBool(), "lt:any1_any1");
76+
testScalarFunctionLookup({"lt", {kI16(), kI16()}, kBool()}, "lt:any1_any1");
8177

82-
testScalarFunctionLookup("lt", {kI32(), kI32()}, kBool(), "lt:any1_any1");
78+
testScalarFunctionLookup({"lt", {kI32(), kI32()}, kBool()}, "lt:any1_any1");
8379

84-
testScalarFunctionLookup("lt", {kI64(), kI64()}, kBool(), "lt:any1_any1");
80+
testScalarFunctionLookup({"lt", {kI64(), kI64()}, kBool()}, "lt:any1_any1");
8581

86-
testScalarFunctionLookup("lt", {kFp32(), kFp32()}, kBool(), "lt:any1_any1");
82+
testScalarFunctionLookup({"lt", {kFp32(), kFp32()}, kBool()}, "lt:any1_any1");
8783

88-
testScalarFunctionLookup("lt", {kFp64(), kFp64()}, kBool(), "lt:any1_any1");
89-
testScalarFunctionLookup("between", {kI8(), kI8(), kI8()}, kBool(),
84+
testScalarFunctionLookup({"lt", {kFp64(), kFp64()}, kBool()}, "lt:any1_any1");
85+
testScalarFunctionLookup({"between", {kI8(), kI8(), kI8()}, kBool()},
9086
"between:any1_any1_any1");
9187
}
9288

9389
TEST_F(SubstraitFunctionLookupTest, arithmetic_function) {
94-
testScalarFunctionLookup("add", {kI8(), kI8()}, kI8(), "add:opt_i8_i8");
95-
96-
testScalarFunctionLookup("plus", {kI8(), kI8()}, kI8(), "add:opt_i8_i8");
97-
testScalarFunctionLookup("divide",
98-
{
99-
kFp32(),
100-
kFp32(),
101-
},
102-
kFp32(), "divide:opt_opt_opt_fp32_fp32");
103-
104-
testAggregateFunctionLookup("avg", {Type::decode("struct<fp64,i64>")},
105-
kFp32(), "avg:opt_fp32");
90+
testScalarFunctionLookup({"add", {kI8(), kI8()}, kI8()}, "add:opt_i8_i8");
91+
92+
testScalarFunctionLookup({"plus", {kI8(), kI8()}, kI8()}, "add:opt_i8_i8");
93+
testScalarFunctionLookup({"divide",
94+
{
95+
kFp32(),
96+
kFp32(),
97+
},
98+
kFp32()},
99+
"divide:opt_opt_opt_fp32_fp32");
100+
101+
testAggregateFunctionLookup(
102+
{"avg", {Type::decode("struct<fp64,i64>")}, kFp32()}, "avg:opt_fp32");
106103
}
107104

108105
TEST_F(SubstraitFunctionLookupTest, avg) {}
109106

110107
TEST_F(SubstraitFunctionLookupTest, logical) {
111-
testScalarFunctionLookup("and", {kBool(), kBool()}, kBool(), "and:bool");
112-
testScalarFunctionLookup("or", {kBool(), kBool()}, kBool(), "or:bool");
113-
testScalarFunctionLookup("not", {kBool()}, kBool(), "not:bool");
114-
testScalarFunctionLookup("xor", {kBool(), kBool()}, kBool(), "xor:bool_bool");
108+
testScalarFunctionLookup({"and", {kBool(), kBool()}, kBool()}, "and:bool");
109+
testScalarFunctionLookup({"or", {kBool(), kBool()}, kBool()}, "or:bool");
110+
testScalarFunctionLookup({"not", {kBool()}, kBool()}, "not:bool");
111+
testScalarFunctionLookup({"xor", {kBool(), kBool()}, kBool()},
112+
"xor:bool_bool");
115113
}
116114

117115
TEST_F(SubstraitFunctionLookupTest, string_function) {
118-
testScalarFunctionLookup("like", {kString(), kString()}, kBool(),
116+
testScalarFunctionLookup({"like", {kString(), kString()}, kBool()},
119117
"like:opt_str_str");
120118
testScalarFunctionLookup(
121-
"like", {Type::decode("varchar<L1>"), Type::decode("varchar<L2>")},
122-
kBool(), "like:opt_vchar<L1>_vchar<L2>");
123-
testScalarFunctionLookup("substr", {kString(), kI32(), kI32()}, kString(),
119+
{"like",
120+
{Type::decode("varchar<L1>"), Type::decode("varchar<L2>")},
121+
kBool()},
122+
"like:opt_vchar<L1>_vchar<L2>");
123+
testScalarFunctionLookup({"substr", {kString(), kI32(), kI32()}, kString()},
124124
"substring:str_i32_i32");
125125
}

0 commit comments

Comments
 (0)