Skip to content

Commit 2303ec9

Browse files
likun666661 and lriggs
authored and committed
optimize gandiva cache
1 parent 713c57a commit 2303ec9

14 files changed

+256
-72
lines changed

cpp/src/gandiva/annotator.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,25 @@ void Annotator::PrepareBuffersForField(const FieldDescriptor& desc,
9393
}
9494
}
9595

96+
// Checks that every column of `record_batch` whose name is registered in
// in_name_to_desc_ carries the same data type as the descriptor stored under
// that name.
//
// Columns whose names are not present in in_name_to_desc_ are ignored.
//
// Returns Status::OK() when all registered columns match; otherwise an
// ExecutionError naming the column, the expected type and the actual type.
const Status Annotator::CheckEvalBatchFieldType(
    const arrow::RecordBatch& record_batch) const {
  for (int i = 0; i < record_batch.num_columns(); ++i) {
    const std::string& name = record_batch.column_name(i);
    const auto found = in_name_to_desc_.find(name);
    if (found == in_name_to_desc_.end()) {
      // skip columns not involved in the expression.
      continue;
    }

    // Fetch the column once and keep it alive while its type is inspected.
    const auto column = record_batch.column(i);
    const auto expected_type = found->second->Type();

    // Compare the complete types rather than only their type ids: two
    // parameterized types (e.g. decimals with different precision/scale) share
    // an id but are not interchangeable, and the message below reports the
    // full type on both sides.
    if (!column->type()->Equals(*expected_type)) {
      return Status::ExecutionError("Expect field ", name, " type is ",
                                    expected_type->ToString(), ", input field ", name,
                                    " type is ", column->type()->ToString());
    }
  }
  return Status::OK();
}
114+
96115
EvalBatchPtr Annotator::PrepareEvalBatch(const arrow::RecordBatch& record_batch,
97116
const ArrayDataVector& out_vector) const {
98117
EvalBatchPtr eval_batch = std::make_shared<EvalBatch>(

cpp/src/gandiva/annotator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ class GANDIVA_EXPORT Annotator {
6060
EvalBatchPtr PrepareEvalBatch(const arrow::RecordBatch& record_batch,
6161
const ArrayDataVector& out_vector) const;
6262

63+
const Status CheckEvalBatchFieldType(const arrow::RecordBatch& record_batch) const;
64+
6365
// Returns the current value of buffer_count_.
int buffer_count() const {
  return buffer_count_;
}
6466

6567
private:

cpp/src/gandiva/expr_validator.cc

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -72,20 +72,6 @@ Status ExprValidator::Visit(const FieldNode& node) {
7272
Status::ExpressionValidationError("Field ", node.field()->name(),
7373
" has unsupported data type ",
7474
node.return_type()->name()));
75-
76-
// Ensure that field is found in schema
77-
auto field_in_schema_entry = field_map_.find(node.field()->name());
78-
ARROW_RETURN_IF(field_in_schema_entry == field_map_.end(),
79-
Status::ExpressionValidationError("Field ", node.field()->name(),
80-
" not in schema."));
81-
82-
// Ensure that the found field matches.
83-
FieldPtr field_in_schema = field_in_schema_entry->second;
84-
ARROW_RETURN_IF(!field_in_schema->Equals(node.field()),
85-
Status::ExpressionValidationError(
86-
"Field definition in schema ", field_in_schema->ToString(),
87-
" different from field in expression ", node.field()->ToString()));
88-
8975
return Status::OK();
9076
}
9177

cpp/src/gandiva/expression.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,6 @@ namespace gandiva {
2222

2323
// Renders the expression by delegating to ToString() of its root node.
std::string Expression::ToString() {
  const auto& root_node = root();
  return root_node->ToString();
}
2424

25+
// Produces the cache-key rendering of the expression by delegating to
// ToCacheKeyString() of its root node.
std::string Expression::ToCacheKeyString() {
  const auto& root_node = root();
  return root_node->ToCacheKeyString();
}
26+
2527
} // namespace gandiva

cpp/src/gandiva/expression.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class GANDIVA_EXPORT Expression {
3838

3939
std::string ToString();
4040

41+
std::string ToCacheKeyString();
42+
4143
private:
4244
const NodePtr root_;
4345
const FieldPtr result_;

cpp/src/gandiva/expression_cache_key.h

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -34,45 +34,42 @@ class ExpressionCacheKey {
3434
public:
3535
ExpressionCacheKey(SchemaPtr schema, std::shared_ptr<Configuration> configuration,
3636
ExpressionVector expression_vector, SelectionVector::Mode mode)
37-
: schema_(schema), mode_(mode), uniquifier_(0), configuration_(configuration) {
37+
: mode_(mode), uniqifier_(0), configuration_(configuration) {
3838
static const int kSeedValue = 4;
3939
size_t result = kSeedValue;
4040
for (auto& expr : expression_vector) {
41-
std::string expr_as_string = expr->ToString();
42-
expressions_as_strings_.push_back(expr_as_string);
43-
arrow::internal::hash_combine(result, expr_as_string);
44-
UpdateUniquifier(expr_as_string);
41+
std::string expr_cache_key_string = expr->ToCacheKeyString();
42+
expressions_as_cache_key_strings_.push_back(expr_cache_key_string);
43+
arrow::internal::hash_combine(result, expr_cache_key_string);
44+
UpdateUniqifier(expr_cache_key_string);
4545
}
4646
arrow::internal::hash_combine(result, static_cast<size_t>(mode));
4747
arrow::internal::hash_combine(result, configuration->Hash());
48-
arrow::internal::hash_combine(result, schema_->ToString());
49-
arrow::internal::hash_combine(result, uniquifier_);
48+
arrow::internal::hash_combine(result, uniqifier_);
5049
hash_code_ = result;
5150
}
5251

5352
ExpressionCacheKey(SchemaPtr schema, std::shared_ptr<Configuration> configuration,
5453
Expression& expression)
55-
: schema_(schema),
56-
mode_(SelectionVector::MODE_NONE),
57-
uniquifier_(0),
54+
:mode_(SelectionVector::MODE_NONE),
55+
uniqifier_(0),
5856
configuration_(configuration) {
5957
static const int kSeedValue = 4;
6058
size_t result = kSeedValue;
61-
expressions_as_strings_.push_back(expression.ToString());
62-
UpdateUniquifier(expression.ToString());
63-
59+
expressions_as_cache_key_strings_.push_back(expression.ToCacheKeyString());
60+
UpdateUniqifier(expression.ToCacheKeyString());
61+
arrow::internal::hash_combine(result,expression.ToCacheKeyString());
6462
arrow::internal::hash_combine(result, configuration->Hash());
65-
arrow::internal::hash_combine(result, schema_->ToString());
66-
arrow::internal::hash_combine(result, uniquifier_);
63+
arrow::internal::hash_combine(result, uniqifier_);
6764
hash_code_ = result;
6865
}
6966

70-
void UpdateUniquifier(const std::string& expr) {
71-
if (uniquifier_ == 0) {
67+
void UpdateUniqifier(const std::string& expr) {
68+
if (uniqifier_ == 0) {
7269
// caching of expressions with re2 patterns causes lock contention. So, use
7370
// multiple instances to reduce contention.
7471
if (expr.find(" like(") != std::string::npos) {
75-
uniquifier_ = std::hash<std::thread::id>()(std::this_thread::get_id()) % 16;
72+
uniqifier_ = std::hash<std::thread::id>()(std::this_thread::get_id()) % 16;
7673
}
7774
}
7875
}
@@ -84,9 +81,6 @@ class ExpressionCacheKey {
8481
return false;
8582
}
8683

87-
if (!(schema_->Equals(*other.schema_, true))) {
88-
return false;
89-
}
9084

9185
if (configuration_ != other.configuration_) {
9286
return false;
@@ -96,11 +90,11 @@ class ExpressionCacheKey {
9690
return false;
9791
}
9892

99-
if (expressions_as_strings_ != other.expressions_as_strings_) {
93+
if (expressions_as_cache_key_strings_ != other.expressions_as_cache_key_strings_) {
10094
return false;
10195
}
10296

103-
if (uniquifier_ != other.uniquifier_) {
97+
if (uniqifier_ != other.uniqifier_) {
10498
return false;
10599
}
106100

@@ -111,10 +105,9 @@ class ExpressionCacheKey {
111105

112106
private:
113107
size_t hash_code_;
114-
SchemaPtr schema_;
115-
std::vector<std::string> expressions_as_strings_;
108+
std::vector<std::string> expressions_as_cache_key_strings_;
116109
SelectionVector::Mode mode_;
117-
uint32_t uniquifier_;
110+
uint32_t uniqifier_;
118111
std::shared_ptr<Configuration> configuration_;
119112
};
120113

cpp/src/gandiva/filter.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,6 @@ Status Filter::Make(SchemaPtr schema, ConditionPtr condition,
9191
Status Filter::Evaluate(const arrow::RecordBatch& batch,
9292
std::shared_ptr<SelectionVector> out_selection) {
9393
const auto num_rows = batch.num_rows();
94-
ARROW_RETURN_IF(!batch.schema()->Equals(*schema_),
95-
Status::Invalid("RecordBatch schema must expected filter schema"));
9694
ARROW_RETURN_IF(num_rows == 0, Status::Invalid("RecordBatch must be non-empty."));
9795
ARROW_RETURN_IF(out_selection == nullptr,
9896
Status::Invalid("out_selection must be non-null."));

cpp/src/gandiva/llvm_generator.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,10 @@ Status LLVMGenerator::Execute(const arrow::RecordBatch& record_batch,
131131
const ArrayDataVector& output_vector) const {
132132
DCHECK_GT(record_batch.num_rows(), 0);
133133

134+
auto status = annotator_.CheckEvalBatchFieldType(record_batch);
135+
136+
ARROW_RETURN_IF(!status.ok(), status);
137+
134138
auto eval_batch = annotator_.PrepareEvalBatch(record_batch, output_vector);
135139
DCHECK_GT(eval_batch->GetNumBuffers(), 0);
136140

cpp/src/gandiva/node.h

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class GANDIVA_EXPORT Node {
4848

4949
virtual std::string ToString() const = 0;
5050

51+
virtual std::string ToCacheKeyString() const = 0;
52+
5153
protected:
5254
DataTypePtr return_type_;
5355
};
@@ -99,6 +101,8 @@ class GANDIVA_EXPORT LiteralNode : public Node {
99101
return ss.str();
100102
}
101103

104+
std::string ToCacheKeyString() const override { return ToString(); }
105+
102106
private:
103107
LiteralHolder holder_;
104108
bool is_null_;
@@ -117,6 +121,10 @@ class GANDIVA_EXPORT FieldNode : public Node {
117121
return "(" + field()->type()->ToString() + ") " + field()->name();
118122
}
119123

124+
std::string ToCacheKeyString() const override {
125+
return "(" + field()->type()->ToString() + ") ";
126+
}
127+
120128
private:
121129
FieldPtr field_;
122130
};
@@ -149,6 +157,24 @@ class GANDIVA_EXPORT FunctionNode : public Node {
149157
return ss.str();
150158
}
151159

160+
std::string ToCacheKeyString() const override {
161+
std::stringstream ss;
162+
ss << ((return_type() == NULLPTR) ? "untyped"
163+
: descriptor()->return_type()->ToString())
164+
<< " " << descriptor()->name() << "(";
165+
bool skip_comma = true;
166+
for (auto& child : children()) {
167+
if (skip_comma) {
168+
ss << child->ToCacheKeyString();
169+
skip_comma = false;
170+
} else {
171+
ss << ", " << child->ToCacheKeyString();
172+
}
173+
}
174+
ss << ")";
175+
return ss.str();
176+
}
177+
152178
private:
153179
FuncDescriptorPtr descriptor_;
154180
NodeVector children_;
@@ -188,6 +214,14 @@ class GANDIVA_EXPORT IfNode : public Node {
188214
return ss.str();
189215
}
190216

217+
// Cache-key rendering of a conditional:
//   "if (<condition key>) { <then key> } else { <else key> }"
// Each of the three sub-nodes contributes its own ToCacheKeyString().
std::string ToCacheKeyString() const override {
  std::string key = "if (";
  key += condition()->ToCacheKeyString();
  key += ") { ";
  key += then_node()->ToCacheKeyString();
  key += " } else { ";
  key += else_node()->ToCacheKeyString();
  key += " }";
  return key;
}
224+
191225
private:
192226
NodePtr condition_;
193227
NodePtr then_node_;
@@ -225,6 +259,23 @@ class GANDIVA_EXPORT BooleanNode : public Node {
225259
return ss.str();
226260
}
227261

262+
std::string ToCacheKeyString() const override {
263+
std::stringstream ss;
264+
bool first = true;
265+
for (auto& child : children_) {
266+
if (!first) {
267+
if (expr_type() == BooleanNode::AND) {
268+
ss << " && ";
269+
} else {
270+
ss << " || ";
271+
}
272+
}
273+
ss << child->ToCacheKeyString();
274+
first = false;
275+
}
276+
return ss.str();
277+
}
278+
228279
private:
229280
ExprType expr_type_;
230281
NodeVector children_;
@@ -265,6 +316,22 @@ class InExpressionNode : public Node {
265316
return ss.str();
266317
}
267318

319+
std::string ToCacheKeyString() const override {
320+
std::stringstream ss;
321+
ss << eval_expr_->ToCacheKeyString() << " IN (";
322+
bool add_comma = false;
323+
for (auto& value : values_) {
324+
if (add_comma) {
325+
ss << ", ";
326+
}
327+
// add type in the front to differentiate
328+
ss << value;
329+
add_comma = true;
330+
}
331+
ss << ")";
332+
return ss.str();
333+
}
334+
268335
private:
269336
NodePtr eval_expr_;
270337
std::unordered_set<Type> values_;
@@ -309,6 +376,22 @@ class InExpressionNode<gandiva::DecimalScalar128> : public Node {
309376
return ss.str();
310377
}
311378

379+
std::string ToCacheKeyString() const override {
380+
std::stringstream ss;
381+
ss << eval_expr_->ToCacheKeyString() << " IN (";
382+
bool add_comma = false;
383+
for (auto& value : values_) {
384+
if (add_comma) {
385+
ss << ", ";
386+
}
387+
// add type in the front to differentiate
388+
ss << value;
389+
add_comma = true;
390+
}
391+
ss << ")";
392+
return ss.str();
393+
}
394+
312395
private:
313396
NodePtr eval_expr_;
314397
std::unordered_set<gandiva::DecimalScalar128> values_;

cpp/src/gandiva/projector.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,6 @@ Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records,
230230
}
231231

232232
Status Projector::ValidateEvaluateArgsCommon(const arrow::RecordBatch& batch) const {
233-
ARROW_RETURN_IF(!batch.schema()->Equals(*schema_),
234-
Status::Invalid("Schema in RecordBatch must match schema in Make()"));
235233
ARROW_RETURN_IF(batch.num_rows() == 0,
236234
Status::Invalid("RecordBatch must be non-empty."));
237235

0 commit comments

Comments
 (0)