1515import itertools
1616from collections import namedtuple
1717from enum import Enum
18- from typing import Dict , List , Optional , Union , Tuple
18+ from typing import Any , Dict , List , Optional , Union , Tuple
1919
2020import numpy as np
2121import pandas as pd
2727from ...serialization .serializables import (
2828 AnyField ,
2929 BoolField ,
30+ DictField ,
3031 StringField ,
3132 TupleField ,
3233 KeyField ,
5051import logging
5152
5253logger = logging .getLogger (__name__ )
54+ DEFAULT_BLOOM_FILTER_CHUNK_THRESHOLD = 10
55+ # use bloom filter to filter large DataFrame
56+ BLOOM_FILTER_OPTIONS = [
57+ "max_elements" ,
58+ "error_rate" ,
59+ "apply_chunk_size_threshold" ,
60+ "filter" ,
61+ "combine_size" ,
62+ ]
63+ BLOOM_FILTER_ON_OPTIONS = ["large" , "small" , "both" ]
64+ DEFAULT_BLOOM_FILTER_ON = "large"
5365
5466
5567class DataFrameMergeAlign (MapReduceOperand , DataFrameOperandMixin ):
@@ -157,6 +169,7 @@ class DataFrameMerge(DataFrameOperand, DataFrameOperandMixin):
157169 auto_merge = StringField ("auto_merge" )
158170 auto_merge_threshold = Int32Field ("auto_merge_threshold" )
159171 bloom_filter = AnyField ("bloom_filter" )
172+ bloom_filter_options = DictField ("bloom_filter_options" )
160173
161174 # only for broadcast merge
162175 split_info = NamedTupleField ("split_info" )
@@ -265,24 +278,39 @@ def _apply_bloom_filter(
265278 op : "DataFrameMerge" ,
266279 ):
267280 bloom_filter_params = dict ()
268- if isinstance (op .bloom_filter , dict ):
269- if "max_elements" in op .bloom_filter :
270- bloom_filter_params ["max_elements" ] = op .bloom_filter ["max_elements" ]
271- if "error_rate" in op .bloom_filter :
272- bloom_filter_params ["error_rate" ] = op .bloom_filter ["error_rate" ]
281+ bloom_filter_options = op .bloom_filter_options or dict ()
282+ for option in ["max_elements" , "error_rate" , "combine_size" ]:
283+ if option in bloom_filter_options :
284+ bloom_filter_params [option ] = bloom_filter_options [option ]
273285 if "max_elements" not in bloom_filter_params :
274286 bloom_filter_params ["max_elements" ] = max (
275287 c .shape [0 ] for c in left .chunks + right .chunks
276288 )
277- if len (left .chunks ) > len (right .chunks ):
289+ filter_on = bloom_filter_options .get ("filter" , DEFAULT_BLOOM_FILTER_ON )
290+ if filter_on == "large" :
291+ if len (left .chunks ) > len (right .chunks ):
292+ left = filter_by_bloom_filter (
293+ left , right , left_on , right_on , ** bloom_filter_params
294+ )
295+ else :
296+ right = filter_by_bloom_filter (
297+ right , left , right_on , left_on , ** bloom_filter_params
298+ )
299+ elif filter_on == "small" :
300+ if len (left .chunks ) < len (right .chunks ):
301+ left = filter_by_bloom_filter (
302+ left , right , left_on , right_on , ** bloom_filter_params
303+ )
304+ else :
305+ right = filter_by_bloom_filter (
306+ right , left , right_on , left_on , ** bloom_filter_params
307+ )
308+ else :
309+ assert filter_on == "both"
310+ # both
278311 left = filter_by_bloom_filter (
279- left ,
280- right ,
281- left_on ,
282- right_on ,
283- ** bloom_filter_params ,
312+ left , right , left_on , right_on , ** bloom_filter_params
284313 )
285- else :
286314 right = filter_by_bloom_filter (
287315 right , left , right_on , left_on , ** bloom_filter_params
288316 )
@@ -587,15 +615,29 @@ def _if_apply_bloom_filter(
587615 op : "DataFrameMerge" ,
588616 left : TileableType ,
589617 right : TileableType ,
590- bloom_filter_chunk_threshold : int ,
591618 ):
592- if len (left .chunks + right .chunks ) <= bloom_filter_chunk_threshold :
619+ # bloom filter can only work for inner merge
620+ if op .how != "inner" or op .bloom_filter is False :
593621 return False
594- elif method == MergeMethod . shuffle and op .bloom_filter :
622+ elif op .bloom_filter is True :
595623 return True
596- else :
624+
625+ bloom_filter_options = op .bloom_filter_options or dict ()
626+ bloom_filter_chunk_threshold = bloom_filter_options .get (
627+ "apply_chunk_size_threshold" , DEFAULT_BLOOM_FILTER_CHUNK_THRESHOLD
628+ )
629+
630+ # TODO(hks): disable bloom_filter for now, when it is ready, turn it on then
631+ # bloom_filter == auto
632+ if len (left .chunks + right .chunks ) <= bloom_filter_chunk_threshold :
633+ # if size of input chunks <= threshold, skip bloom filter
634+ return False
635+ elif method == MergeMethod .shuffle :
636+ # for shuffle, bloom filter is meant to be enabled by default, but it is disabled for now (see TODO above)
597637 return False
598638
639+ return False
640+
599641 @classmethod
600642 def tile (cls , op : "DataFrameMerge" ):
601643 left = build_concatenated_rows_frame (op .inputs [0 ])
@@ -612,36 +654,42 @@ def tile(cls, op: "DataFrameMerge"):
612654 yield TileStatus ([left , right ] + left .chunks + right .chunks , progress = 0.2 )
613655 left = auto_merge_chunks (ctx , left )
614656 right = auto_merge_chunks (ctx , right )
657+ logger .debug (
658+ "Before merge %s, left data count: %d, chunk size: %d, "
659+ "right data count: %d, chunk_size: %d" ,
660+ op ,
661+ left .shape [0 ],
662+ len (left .chunks ),
663+ right .shape [0 ],
664+ len (right .chunks ),
665+ )
666+ else :
667+ logger .debug (
668+ "Skip auto merge before %s, left chunk size: %d, right chunk size: %d" ,
669+ op ,
670+ len (left .chunks ),
671+ len (right .chunks ),
672+ )
615673
616674 method = cls ._choose_merge_method (op , left , right )
617- bloom_filter_chunk_threshold = 10
618- if isinstance (op .bloom_filter , dict ):
619- bloom_filter_chunk_threshold = op .bloom_filter .pop (
620- "apply_chunk_size_threshold" , bloom_filter_chunk_threshold
621- )
622- if cls ._if_apply_bloom_filter (
623- method , op , left , right , bloom_filter_chunk_threshold
624- ):
675+ if cls ._if_apply_bloom_filter (method , op , left , right ):
676+ if has_unknown_shape (left , right ): # pragma: no cover
677+ yield TileStatus (left .chunks + right .chunks , progress = 0.3 )
625678 left_on = _prepare_shuffle_on (op .left_index , op .left_on , op .on )
626679 right_on = _prepare_shuffle_on (op .right_index , op .right_on , op .on )
627- if op .how == "inner" and op .bloom_filter :
628- if has_unknown_shape (left , right ):
629- yield TileStatus (left .chunks + right .chunks , progress = 0.3 )
630- small_one = right if len (left .chunks ) > len (right .chunks ) else left
631- logger .debug (
632- "Apply bloom filter for operand %s, use DataFrame %s to build bloom filter." ,
633- op ,
634- small_one ,
635- )
636- left , right = yield from recursive_tile (
637- * cls ._apply_bloom_filter (left , right , left_on , right_on , op )
638- )
639- # auto merge after bloom filter
640- yield TileStatus (
641- [left , right ] + left .chunks + right .chunks , progress = 0.5
642- )
643- left = auto_merge_chunks (ctx , left )
644- right = auto_merge_chunks (ctx , right )
680+ small_one = right if len (left .chunks ) > len (right .chunks ) else left
681+ logger .debug (
682+ "Apply bloom filter for operand %s, use DataFrame %s to build bloom filter." ,
683+ op ,
684+ small_one ,
685+ )
686+ left , right = yield from recursive_tile (
687+ * cls ._apply_bloom_filter (left , right , left_on , right_on , op )
688+ )
689+ # auto merge after bloom filter
690+ yield TileStatus ([left , right ] + left .chunks + right .chunks , progress = 0.5 )
691+ left = auto_merge_chunks (ctx , left )
692+ right = auto_merge_chunks (ctx , right )
645693
646694 if op .method == "auto" :
647695 # if method is auto, select new method after auto merge
@@ -665,8 +713,18 @@ def tile(cls, op: "DataFrameMerge"):
665713 yield TileStatus (
666714 ret [0 ].chunks , progress = 0.8
667715 ) # trigger execution for chunks
668- return [auto_merge_chunks (get_context (), ret [0 ])]
716+ merged = auto_merge_chunks (get_context (), ret [0 ])
717+ logger .debug (
718+ "After merge %s, data size: %d, chunk size: %d" ,
719+ op ,
720+ merged .shape [0 ],
721+ len (merged .chunks ),
722+ )
723+ return [merged ]
669724 else :
725+ logger .debug (
726+ "Skip auto merge after %s, chunk size: %d" , op , len (ret [0 ].chunks )
727+ )
670728 return ret
671729
672730 @classmethod
@@ -750,7 +808,8 @@ def merge(
750808 method : str = "auto" ,
751809 auto_merge : str = "both" ,
752810 auto_merge_threshold : int = 8 ,
753- bloom_filter : Union [bool , Dict ] = True ,
811+ bloom_filter : Union [bool , str ] = "auto" ,
812+ bloom_filter_options : Dict [str , Any ] = None ,
754813) -> DataFrame :
755814 """
756815 Merge DataFrame or named Series objects with a database-style join.
@@ -843,17 +902,16 @@ def merge(
843902 When how is "inner", merged result could be much smaller than original DataFrame,
844903 if the number of chunks is greater than the threshold,
845904 it will merge small chunks automatically.
846- bloom_filter: bool or dict, default True
847- Use bloom filter to optimize merge, you can pass a dict to specify arguments for
848- bloom filter.
849-
850- If is a dict:
851-
905+ bloom_filter: bool, str, default "auto"
906+ Use bloom filter to optimize merge
907+ bloom_filter_options: dict
852908 * "max_elements": max elements in bloom filter,
853909 default value is the max size of all input chunks
854910 * "error_rate": error rate, default 0.1.
855911 * "apply_chunk_size_threshold": min number of input chunks to apply bloom filter, default 10
856912 when the total number of chunks of left and right is greater than this threshold, apply bloom filter
913+ * "filter": "large", "small", "both", default "large"
914+ decides to filter on large, small or both DataFrames.
857915
858916 Returns
859917 -------
@@ -944,8 +1002,26 @@ def merge(
9441002 raise NotImplementedError (f"{ method } merge is not supported" )
9451003 if auto_merge not in ["both" , "none" , "before" , "after" ]: # pragma: no cover
9461004 raise ValueError (
947- f"{ auto_merge } can only be `both`, `none`, `before` or `after`"
1005+ f"auto_merge can only be `both`, `none`, `before` or `after`, got { auto_merge } "
1006+ )
1007+ if bloom_filter not in [True , False , "auto" ]:
1008+ raise ValueError (
1009+ f'bloom_filter can only be True, False, or "auto", got { bloom_filter } '
9481010 )
1011+ if bloom_filter_options :
1012+ if not isinstance (bloom_filter_options , dict ):
1013+ raise TypeError (
1014+ f"bloom_filter_options must be a dict, got { type (bloom_filter_options )} "
1015+ )
1016+ for k , v in bloom_filter_options .items ():
1017+ if k not in BLOOM_FILTER_OPTIONS :
1018+ raise ValueError (
1019+ f"Invalid bloom filter option { k } , available: { BLOOM_FILTER_OPTIONS } "
1020+ )
1021+ if k == "filter" and v not in BLOOM_FILTER_ON_OPTIONS :
1022+ raise ValueError (
1023+ f"Invalid filter { k } , available: { BLOOM_FILTER_ON_OPTIONS } "
1024+ )
9491025 op = DataFrameMerge (
9501026 how = how ,
9511027 on = on ,
@@ -962,6 +1038,7 @@ def merge(
9621038 auto_merge = auto_merge ,
9631039 auto_merge_threshold = auto_merge_threshold ,
9641040 bloom_filter = bloom_filter ,
1041+ bloom_filter_options = bloom_filter_options ,
9651042 output_types = [OutputType .dataframe ],
9661043 )
9671044 return op (df , right )
@@ -979,6 +1056,7 @@ def join(
9791056 auto_merge : str = "both" ,
9801057 auto_merge_threshold : int = 8 ,
9811058 bloom_filter : Union [bool , Dict ] = True ,
1059+ bloom_filter_options : Dict [str , Any ] = None ,
9821060) -> DataFrame :
9831061 """
9841062 Join columns of another DataFrame.
@@ -1033,17 +1111,16 @@ def join(
10331111 When how is "inner", merged result could be much smaller than original DataFrame,
10341112 if the number of chunks is greater than the threshold,
10351113 it will merge small chunks automatically.
1036- bloom_filter: bool or dict, default True
1037- Use bloom filter to optimize merge, you can pass a dict to specify arguments for
1038- bloom filter.
1039-
1040- If is a dict:
1041-
1114+ bloom_filter: bool, str, default "auto"
1115+ Use bloom filter to optimize merge
1116+ bloom_filter_options: dict
10421117 * "max_elements": max elements in bloom filter,
10431118 default value is the max size of all input chunks
10441119 * "error_rate": error rate, default 0.1.
10451120 * "apply_chunk_size_threshold": min number of input chunks to apply bloom filter, default 10
10461121 when the total number of chunks of left and right is greater than this threshold, apply bloom filter
1122+ * "filter": "large", "small", "both", default "large"
1123+ decides to filter on large, small or both DataFrames.
10471124
10481125 Returns
10491126 -------
@@ -1153,4 +1230,5 @@ def join(
11531230 auto_merge = auto_merge ,
11541231 auto_merge_threshold = auto_merge_threshold ,
11551232 bloom_filter = bloom_filter ,
1233+ bloom_filter_options = bloom_filter_options ,
11561234 )
0 commit comments