@@ -377,7 +377,9 @@ class Aggregate : public Node {
377
377
NodePtr input)
378
378
: groupby_count_(groupby_count)
379
379
, aggs_(std::move(aggs))
380
- , fields_(std::move(fields)) {
380
+ , fields_(std::move(fields))
381
+ , partitioned_(false )
382
+ , buffer_entry_count_hint_(0 ) {
381
383
inputs_.emplace_back (std::move (input));
382
384
}
383
385
@@ -406,6 +408,12 @@ class Aggregate : public Node {
406
408
407
409
void setAggExprs (ExprPtrVector new_aggs) { aggs_ = std::move (new_aggs); }
408
410
411
+ bool isPartitioned () const { return partitioned_; }
412
+ void setPartitioned (bool val) { partitioned_ = val; }
413
+
414
+ size_t bufferEntryCountHint () const { return buffer_entry_count_hint_; }
415
+ void setBufferEntryCountHint (size_t val) { buffer_entry_count_hint_ = val; }
416
+
409
417
void rewriteExprs (hdk::ir::ExprRewriter& rewriter) override ;
410
418
411
419
std::string toString () const override {
@@ -417,6 +425,7 @@ class Aggregate : public Node {
417
425
::toString (aggs_),
418
426
", fields=",
419
427
::toString(fields_),
428
+ (partitioned_ ? " , partitioned" : " " ),
420
429
", inputs=",
421
430
inputsToString(inputs_),
422
431
")");
@@ -445,6 +454,8 @@ class Aggregate : public Node {
445
454
const size_t groupby_count_;
446
455
ExprPtrVector aggs_;
447
456
std::vector<std::string> fields_;
457
+ bool partitioned_;
458
+ size_t buffer_entry_count_hint_;
448
459
};
449
460
450
461
class Join : public Node {
@@ -593,7 +604,9 @@ class TranslatedJoin : public Node {
593
604
CHECK(false);
594
605
return nullptr;
595
606
}
596
- const std::string& getFieldName(size_t i) const override { CHECK(false); }
607
+ const std::string& getFieldName(size_t i) const override {
608
+ throw std::runtime_error(" Unexpected call to TranslatedJoin::getFieldName." );
609
+ }
597
610
std::vector<const ColumnVar*> getJoinCols(bool lhs) const {
598
611
if (lhs) {
599
612
return lhs_join_cols_;
@@ -853,6 +866,69 @@ class LogicalUnion : public Node {
853
866
bool const is_all_;
854
867
};
855
868
869
+ struct ShuffleFunction {
870
+ enum Kind {
871
+ kHash,
872
+ };
873
+
874
+ Kind kind;
875
+ size_t partitions;
876
+
877
+ size_t hash() const;
878
+ std::string toString() const;
879
+ };
880
+
881
+ std::ostream& operator<<(std::ostream& os, const ShuffleFunction& fn);
882
+ std::ostream& operator<<(std::ostream& os, ShuffleFunction::Kind kind);
883
+
884
+ class Shuffle : public Node {
885
+ public:
886
+ Shuffle(ExprPtrVector keys,
887
+ ExprPtr expr,
888
+ std::string field,
889
+ ShuffleFunction fn,
890
+ NodePtr input);
891
+ Shuffle(ExprPtrVector keys,
892
+ ExprPtrVector exprs,
893
+ std::vector<std::string> fields,
894
+ ShuffleFunction fn,
895
+ std::vector<NodePtr> input);
896
+ Shuffle(const Shuffle& other) = default;
897
+
898
+ const ExprPtrVector& keys() const { return keys_; }
899
+ const ExprPtrVector& exprs() const { return exprs_; }
900
+ const std::vector<std::string>& fields() const { return fields_; }
901
+ ShuffleFunction fn() const { return fn_; }
902
+
903
+ size_t size() const override { return exprs_.size(); }
904
+
905
+ // Shuffle node can be used for computing partition sizes and perform
906
+ // actual partitioning. The first version uses COUNT aggregte as its
907
+ // only target expression.
908
+ bool isCount() const {
909
+ return exprs_.size() == (size_t)1 && exprs_.front()->is<AggExpr>();
910
+ }
911
+
912
+ std::string toString() const override;
913
+ size_t toHash() const override;
914
+ void rewriteExprs(hdk::ir::ExprRewriter& rewriter) override;
915
+
916
+ std::shared_ptr<Node> deepCopy() const override {
917
+ return std::make_shared<Shuffle>(*this);
918
+ }
919
+
920
+ const std::string& getFieldName(size_t i) const override {
921
+ CHECK_LT(i, fields_.size());
922
+ return fields_[i];
923
+ }
924
+
925
+ private:
926
+ ExprPtrVector keys_;
927
+ ExprPtrVector exprs_;
928
+ std::vector<std::string> fields_;
929
+ ShuffleFunction fn_;
930
+ };
931
+
856
932
class QueryNotSupported : public std::runtime_error {
857
933
public:
858
934
QueryNotSupported(const std::string& reason) : std::runtime_error(reason) {}
@@ -921,3 +997,5 @@ size_t getNodeColumnCount(const Node* node);
921
997
ExprPtr getJoinInputColumnRef(const ColumnRef* col_ref);
922
998
923
999
} // namespace hdk::ir
1000
+
1001
+ std::string toString(hdk::ir::ShuffleFunction::Kind kind);
0 commit comments