Skip to content

Commit ba8a6d9

Browse files
author
Xuye (Chris) Qin
authored
Do not aggressively choose tree method in tile of groupby for distributed setting (#3032)
1 parent acecc9c commit ba8a6d9

File tree

4 files changed

+105
-57
lines changed

4 files changed

+105
-57
lines changed

benchmarks/tpch/run_queries.py

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717
import argparse
1818
import functools
1919
import time
20-
from typing import Callable
20+
from typing import Callable, List, Optional, Set, Union
2121

2222
import mars
2323
import mars.dataframe as md
2424

25-
queries = None
25+
queries: Optional[Union[Set[str], List[str]]] = None
2626

2727

2828
def load_lineitem(data_folder: str) -> md.DataFrame:
@@ -158,7 +158,8 @@ def q01(lineitem: md.DataFrame):
158158
"L_ORDERKEY": "count",
159159
}
160160
)
161-
total = total.sort_values(["L_RETURNFLAG", "L_LINESTATUS"])
161+
# skip sort, Mars groupby enables sort
162+
# total = total.sort_values(["L_RETURNFLAG", "L_LINESTATUS"])
162163
print(total.execute())
163164

164165

@@ -238,7 +239,9 @@ def q02(part, partsupp, supplier, nation, region):
238239
"P_MFGR",
239240
],
240241
]
241-
min_values = merged_df.groupby("P_PARTKEY", as_index=False)["PS_SUPPLYCOST"].min()
242+
min_values = merged_df.groupby("P_PARTKEY", as_index=False, sort=False)[
243+
"PS_SUPPLYCOST"
244+
].min()
242245
min_values.columns = ["P_PARTKEY_CPY", "MIN_SUPPLYCOST"]
243246
merged_df = merged_df.merge(
244247
min_values,
@@ -286,9 +289,9 @@ def q03(lineitem, orders, customer):
286289
jn2 = jn1.merge(flineitem, left_on="O_ORDERKEY", right_on="L_ORDERKEY")
287290
jn2["TMP"] = jn2.L_EXTENDEDPRICE * (1 - jn2.L_DISCOUNT)
288291
total = (
289-
jn2.groupby(["L_ORDERKEY", "O_ORDERDATE", "O_SHIPPRIORITY"], as_index=False)[
290-
"TMP"
291-
]
292+
jn2.groupby(
293+
["L_ORDERKEY", "O_ORDERDATE", "O_SHIPPRIORITY"], as_index=False, sort=False
294+
)["TMP"]
292295
.sum()
293296
.sort_values(["TMP"], ascending=False)
294297
)
@@ -307,9 +310,9 @@ def q04(lineitem, orders):
307310
forders = orders[osel]
308311
jn = forders[forders["O_ORDERKEY"].isin(flineitem["L_ORDERKEY"])]
309312
total = (
310-
jn.groupby("O_ORDERPRIORITY", as_index=False)["O_ORDERKEY"]
311-
.count()
312-
.sort_values(["O_ORDERPRIORITY"])
313+
jn.groupby("O_ORDERPRIORITY", as_index=False)["O_ORDERKEY"].count()
314+
# skip sort when Mars enables sort in groupby
315+
# .sort_values(["O_ORDERPRIORITY"])
313316
)
314317
print(total.execute())
315318

@@ -330,7 +333,7 @@ def q05(lineitem, orders, customer, nation, region, supplier):
330333
jn4, left_on=["S_SUPPKEY", "S_NATIONKEY"], right_on=["L_SUPPKEY", "N_NATIONKEY"]
331334
)
332335
jn5["TMP"] = jn5.L_EXTENDEDPRICE * (1.0 - jn5.L_DISCOUNT)
333-
gb = jn5.groupby("N_NAME", as_index=False)["TMP"].sum()
336+
gb = jn5.groupby("N_NAME", as_index=False, sort=False)["TMP"].sum()
334337
total = gb.sort_values("TMP", ascending=False)
335338
print(total.execute())
336339

@@ -436,9 +439,10 @@ def q07(lineitem, supplier, orders, customer, nation):
436439
total = total.groupby(["SUPP_NATION", "CUST_NATION", "L_YEAR"], as_index=False).agg(
437440
REVENUE=md.NamedAgg(column="VOLUME", aggfunc="sum")
438441
)
439-
total = total.sort_values(
440-
by=["SUPP_NATION", "CUST_NATION", "L_YEAR"], ascending=[True, True, True]
441-
)
442+
# skip sort when Mars groupby does sort already
443+
# total = total.sort_values(
444+
# by=["SUPP_NATION", "CUST_NATION", "L_YEAR"], ascending=[True, True, True]
445+
# )
442446
print(total.execute())
443447

444448

@@ -520,7 +524,7 @@ def q09(lineitem, orders, part, nation, partsupp, supplier):
520524
(1 * jn5.PS_SUPPLYCOST) * jn5.L_QUANTITY
521525
)
522526
jn5["O_YEAR"] = jn5.O_ORDERDATE.dt.year
523-
gb = jn5.groupby(["N_NAME", "O_YEAR"], as_index=False)["TMP"].sum()
527+
gb = jn5.groupby(["N_NAME", "O_YEAR"], as_index=False, sort=False)["TMP"].sum()
524528
total = gb.sort_values(["N_NAME", "O_YEAR"], ascending=[True, False])
525529
print(total.execute())
526530

@@ -548,6 +552,7 @@ def q10(lineitem, orders, customer, nation):
548552
"C_COMMENT",
549553
],
550554
as_index=False,
555+
sort=False,
551556
)["TMP"].sum()
552557
total = gb.sort_values("TMP", ascending=False)
553558
print(total.head(20).execute())
@@ -571,7 +576,7 @@ def q11(partsupp, supplier, nation):
571576
)
572577
ps_supp_n_merge = ps_supp_n_merge.loc[:, ["PS_PARTKEY", "TOTAL_COST"]]
573578
sum_val = ps_supp_n_merge["TOTAL_COST"].sum() * 0.0001
574-
total = ps_supp_n_merge.groupby(["PS_PARTKEY"], as_index=False).agg(
579+
total = ps_supp_n_merge.groupby(["PS_PARTKEY"], as_index=False, sort=False).agg(
575580
VALUE=md.NamedAgg(column="TOTAL_COST", aggfunc="sum")
576581
)
577582
total = total[total["VALUE"] > sum_val]
@@ -603,7 +608,8 @@ def g2(x):
603608

604609
total = jn.groupby("L_SHIPMODE", as_index=False)["O_ORDERPRIORITY"].agg((g1, g2))
605610
total = total.reset_index() # reset index to keep consistency with pandas
606-
total = total.sort_values("L_SHIPMODE")
611+
# skip sort when groupby does sort already
612+
# total = total.sort_values("L_SHIPMODE")
607613
print(total.execute())
608614

609615

@@ -618,10 +624,10 @@ def q13(customer, orders):
618624
orders_filtered, left_on="C_CUSTKEY", right_on="O_CUSTKEY", how="left"
619625
)
620626
c_o_merged = c_o_merged.loc[:, ["C_CUSTKEY", "O_ORDERKEY"]]
621-
count_df = c_o_merged.groupby(["C_CUSTKEY"], as_index=False).agg(
627+
count_df = c_o_merged.groupby(["C_CUSTKEY"], as_index=False, sort=False).agg(
622628
C_COUNT=md.NamedAgg(column="O_ORDERKEY", aggfunc="count")
623629
)
624-
total = count_df.groupby(["C_COUNT"], as_index=False).size()
630+
total = count_df.groupby(["C_COUNT"], as_index=False, sort=False).size()
625631
total.columns = ["C_COUNT", "CUSTDIST"]
626632
total = total.sort_values(by=["CUSTDIST", "C_COUNT"], ascending=[False, False])
627633
print(total.execute())
@@ -660,7 +666,7 @@ def q15(lineitem, supplier):
660666
)
661667
lineitem_filtered = lineitem_filtered.loc[:, ["L_SUPPKEY", "REVENUE_PARTS"]]
662668
revenue_table = (
663-
lineitem_filtered.groupby("L_SUPPKEY", as_index=False)
669+
lineitem_filtered.groupby("L_SUPPKEY", as_index=False, sort=False)
664670
.agg(TOTAL_REVENUE=md.NamedAgg(column="REVENUE_PARTS", aggfunc="sum"))
665671
.rename(columns={"L_SUPPKEY": "SUPPLIER_NO"})
666672
)
@@ -699,7 +705,7 @@ def q16(part, partsupp, supplier):
699705
)
700706
total = total[total["S_SUPPKEY"].isna()]
701707
total = total.loc[:, ["P_BRAND", "P_TYPE", "P_SIZE", "PS_SUPPKEY"]]
702-
total = total.groupby(["P_BRAND", "P_TYPE", "P_SIZE"], as_index=False)[
708+
total = total.groupby(["P_BRAND", "P_TYPE", "P_SIZE"], as_index=False, sort=False)[
703709
"PS_SUPPKEY"
704710
].nunique()
705711
total.columns = ["P_BRAND", "P_TYPE", "P_SIZE", "SUPPLIER_CNT"]
@@ -722,9 +728,9 @@ def q17(lineitem, part):
722728
:, ["L_QUANTITY", "L_EXTENDEDPRICE", "P_PARTKEY"]
723729
]
724730
lineitem_filtered = lineitem.loc[:, ["L_PARTKEY", "L_QUANTITY"]]
725-
lineitem_avg = lineitem_filtered.groupby(["L_PARTKEY"], as_index=False).agg(
726-
avg=md.NamedAgg(column="L_QUANTITY", aggfunc="mean")
727-
)
731+
lineitem_avg = lineitem_filtered.groupby(
732+
["L_PARTKEY"], as_index=False, sort=False
733+
).agg(avg=md.NamedAgg(column="L_QUANTITY", aggfunc="mean"))
728734
lineitem_avg["avg"] = 0.2 * lineitem_avg["avg"]
729735
lineitem_avg = lineitem_avg.loc[:, ["L_PARTKEY", "avg"]]
730736
total = line_part_merge.merge(
@@ -737,13 +743,14 @@ def q17(lineitem, part):
737743

738744
@tpc_query
739745
def q18(lineitem, orders, customer):
740-
gb1 = lineitem.groupby("L_ORDERKEY", as_index=False)["L_QUANTITY"].sum()
746+
gb1 = lineitem.groupby("L_ORDERKEY", as_index=False, sort=False)["L_QUANTITY"].sum()
741747
fgb1 = gb1[gb1.L_QUANTITY > 300]
742748
jn1 = fgb1.merge(orders, left_on="L_ORDERKEY", right_on="O_ORDERKEY")
743749
jn2 = jn1.merge(customer, left_on="O_CUSTKEY", right_on="C_CUSTKEY")
744750
gb2 = jn2.groupby(
745751
["C_NAME", "C_CUSTKEY", "O_ORDERKEY", "O_ORDERDATE", "O_TOTALPRICE"],
746752
as_index=False,
753+
sort=False,
747754
)["L_QUANTITY"].sum()
748755
total = gb2.sort_values(["O_TOTALPRICE", "O_ORDERDATE"], ascending=[False, True])
749756
print(total.head(100).execute())
@@ -865,9 +872,9 @@ def q20(lineitem, part, nation, partsupp, supplier):
865872
left_on=["PS_PARTKEY", "PS_SUPPKEY"],
866873
right_on=["L_PARTKEY", "L_SUPPKEY"],
867874
)
868-
gb = jn2.groupby(["PS_PARTKEY", "PS_SUPPKEY", "PS_AVAILQTY"], as_index=False)[
869-
"L_QUANTITY"
870-
].sum()
875+
gb = jn2.groupby(
876+
["PS_PARTKEY", "PS_SUPPKEY", "PS_AVAILQTY"], as_index=False, sort=False
877+
)["L_QUANTITY"].sum()
871878
gbsel = gb.PS_AVAILQTY > (0.5 * gb.L_QUANTITY)
872879
fgb = gb[gbsel]
873880
jn3 = fgb.merge(supplier, left_on="PS_SUPPKEY", right_on="S_SUPPKEY")
@@ -886,7 +893,7 @@ def q21(lineitem, orders, supplier, nation):
886893
# Keep all rows that have another row in lineitem with the same orderkey and different suppkey
887894
lineitem_orderkeys = (
888895
lineitem_filtered.loc[:, ["L_ORDERKEY", "L_SUPPKEY"]]
889-
.groupby("L_ORDERKEY", as_index=False)["L_SUPPKEY"]
896+
.groupby("L_ORDERKEY", as_index=False, sort=False)["L_SUPPKEY"]
890897
.nunique()
891898
)
892899
lineitem_orderkeys.columns = ["L_ORDERKEY", "nunique_col"]
@@ -905,9 +912,9 @@ def q21(lineitem, orders, supplier, nation):
905912
)
906913

907914
# Not Exists: Check the exists condition isn't still satisfied on the output.
908-
lineitem_orderkeys = lineitem_filtered.groupby("L_ORDERKEY", as_index=False)[
909-
"L_SUPPKEY"
910-
].nunique()
915+
lineitem_orderkeys = lineitem_filtered.groupby(
916+
"L_ORDERKEY", as_index=False, sort=False
917+
)["L_SUPPKEY"].nunique()
911918
lineitem_orderkeys.columns = ["L_ORDERKEY", "nunique_col"]
912919
lineitem_orderkeys = lineitem_orderkeys[lineitem_orderkeys["nunique_col"] == 1]
913920
lineitem_orderkeys = lineitem_orderkeys.loc[:, ["L_ORDERKEY"]]
@@ -936,7 +943,7 @@ def q21(lineitem, orders, supplier, nation):
936943
nation_filtered, left_on="S_NATIONKEY", right_on="N_NATIONKEY", how="inner"
937944
)
938945
total = total.loc[:, ["S_NAME"]]
939-
total = total.groupby("S_NAME", as_index=False).size()
946+
total = total.groupby("S_NAME", as_index=False, sort=False).size()
940947
total.columns = ["S_NAME", "NUMWAIT"]
941948
total = total.sort_values(by=["NUMWAIT", "S_NAME"], ascending=[False, True])
942949
print(total.execute())
@@ -966,17 +973,21 @@ def q22(customer, orders):
966973
customer_filtered, on="C_CUSTKEY", how="inner"
967974
)
968975
customer_selected = customer_selected.loc[:, ["CNTRYCODE", "C_ACCTBAL"]]
969-
agg1 = customer_selected.groupby(["CNTRYCODE"], as_index=False).size()
976+
agg1 = customer_selected.groupby(["CNTRYCODE"], as_index=False, sort=False).size()
970977
agg1.columns = ["CNTRYCODE", "NUMCUST"]
971-
agg2 = customer_selected.groupby(["CNTRYCODE"], as_index=False).agg(
978+
agg2 = customer_selected.groupby(["CNTRYCODE"], as_index=False, sort=False).agg(
972979
TOTACCTBAL=md.NamedAgg(column="C_ACCTBAL", aggfunc="sum")
973980
)
974981
total = agg1.merge(agg2, on="CNTRYCODE", how="inner")
975982
total = total.sort_values(by=["CNTRYCODE"], ascending=[True])
976983
print(total.execute())
977984

978985

979-
def run_queries(data_folder: str):
986+
def run_queries(data_folder: str, select: List[str] = None):
987+
if select:
988+
global queries
989+
queries = select
990+
980991
# Load the data
981992
t1 = time.time()
982993
lineitem = load_lineitem(data_folder)

mars/dataframe/groupby/aggregation.py

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424

2525
from ... import opcodes as OperandDef
2626
from ...config import options
27-
from ...core.custom_log import redirect_custom_log
2827
from ...core import ENTITY_TYPE, OutputType
29-
from ...core.context import get_context
28+
from ...core.custom_log import redirect_custom_log
29+
from ...core.context import get_context, Context
3030
from ...core.operand import OperandStage
3131
from ...serialization.serializables import (
3232
Int32Field,
@@ -65,7 +65,8 @@
6565
cudf = lazy_import("cudf")
6666

6767
logger = logging.getLogger(__name__)
68-
68+
CV_THRESHOLD = 0.2
69+
MEAN_RATIO_THRESHOLD = 2 / 3
6970
_support_get_group_without_as_index = pd_release_version[:2] > (1, 0)
7071

7172

@@ -783,11 +784,36 @@ def _combine_tree(
783784

784785
@classmethod
785786
def _choose_tree_method(
786-
cls, raw_sizes, agg_sizes, sample_count, total_count, chunk_store_limit
787-
):
787+
cls,
788+
raw_sizes: List[int],
789+
agg_sizes: List[int],
790+
sample_count: int,
791+
total_count: int,
792+
chunk_store_limit: int,
793+
ctx: Context,
794+
) -> bool:
795+
logger.debug(
796+
"Start to choose method for Groupby, agg_sizes: %s, raw_sizes: %s, "
797+
"sample_count: %s, total_count: %s, chunk_store_limit: %s",
798+
agg_sizes,
799+
raw_sizes,
800+
sample_count,
801+
total_count,
802+
chunk_store_limit,
803+
)
804+
estimate_size = sum(agg_sizes) / sample_count * total_count
805+
if (
806+
len(ctx.get_worker_addresses()) > 1
807+
and estimate_size > chunk_store_limit
808+
and np.mean(agg_sizes) > 1024**2
809+
):
810+
# for distributed setting, if the estimated size is potentially large,
811+
# and each chunk size is large enough (>1M; small chunks mean large error),
812+
# we choose to use shuffle
813+
return False
788814
# calculate the coefficient of variation of aggregation sizes,
789-
# if the CV is less than 0.2 and the mean of agg_size/raw_size
790-
# is less than 0.8, we suppose the single chunk's aggregation size
815+
# if the CV is less than CV_THRESHOLD and the mean of agg_size/raw_size
816+
# is less than MEAN_RATIO_THRESHOLD, we suppose the single chunk's aggregation size
791817
# almost equals to the tileable's, then use tree method
792818
# as combine aggregation results won't lead to a rapid expansion.
793819
ratios = [
@@ -796,12 +822,11 @@ def _choose_tree_method(
796822
cv = variation(agg_sizes)
797823
mean_ratio = np.mean(ratios)
798824
if mean_ratio <= 1 / sample_count:
799-
# if mean of ratio is less than 0.25, use tree
800825
return True
801-
if cv <= 0.2 and mean_ratio <= 2 / 3:
826+
if cv <= CV_THRESHOLD and mean_ratio <= MEAN_RATIO_THRESHOLD:
802827
# check CV and mean of ratio
803828
return True
804-
elif sum(agg_sizes) / sample_count * total_count <= chunk_store_limit:
829+
if estimate_size <= chunk_store_limit:
805830
# if estimated size less than `chunk_store_limit`, use tree.
806831
return True
807832
return False
@@ -835,9 +860,14 @@ def _tile_auto(
835860
left_chunks = in_df.chunks[combine_size:]
836861
left_chunks = cls._gen_map_chunks(op, left_chunks, out_df, func_infos)
837862
if cls._choose_tree_method(
838-
raw_sizes, agg_sizes, len(chunks), len(in_df.chunks), op.chunk_store_limit
863+
raw_sizes,
864+
agg_sizes,
865+
len(chunks),
866+
len(in_df.chunks),
867+
op.chunk_store_limit,
868+
ctx,
839869
):
840-
logger.debug("Choose tree method for groupby operand %s", op)
870+
logger.info("Choose tree method for groupby operand %s", op)
841871
return cls._combine_tree(op, chunks + left_chunks, out_df, func_infos)
842872
else:
843873
# otherwise, use shuffle
@@ -847,7 +877,7 @@ def _tile_auto(
847877
sample_chunks = cls._sample_chunks(op, chunks + left_chunks)
848878
pivot_chunk = cls._gen_pivot_chunk(op, sample_chunks, agg_chunk_len)
849879

850-
logger.debug("Choose shuffle method for groupby operand %s", op)
880+
logger.info("Choose shuffle method for groupby operand %s", op)
851881
return cls._perform_shuffle(
852882
op, chunks + left_chunks, in_df, out_df, func_infos, pivot_chunk
853883
)

mars/dataframe/groupby/sort.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,12 @@
1818
from ... import opcodes as OperandDef
1919
from ...core import OutputType
2020
from ...core.operand import MapReduceOperand, OperandStage
21-
from ...serialization.serializables import (
22-
Int32Field,
23-
ListField,
24-
)
25-
from ...utils import (
26-
lazy_import,
27-
)
21+
from ...serialization.serializables import Int32Field, ListField
22+
from ...utils import lazy_import
2823
from ..operands import DataFrameOperandMixin
2924
from ..sort.psrs import DataFramePSRSChunkOperand
3025

31-
cudf = lazy_import("cudf", globals=globals())
26+
cudf = lazy_import("cudf")
3227

3328

3429
def _series_to_df(in_series, xdf):

mars/dataframe/groupby/tests/test_groupby_execution.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,18 @@ def _disallow_combine_and_agg(ctx, op):
722722
pd.testing.assert_frame_equal(result.sort_index(), raw.groupby("c1").agg("sum"))
723723

724724

725+
def test_distributed_groupby_agg(setup_cluster):
726+
rs = np.random.RandomState(0)
727+
raw = pd.DataFrame(rs.rand(50000, 10))
728+
df = md.DataFrame(raw, chunk_size=raw.shape[0] // 2)
729+
with option_context({"chunk_store_limit": 1024**2}):
730+
r = df.groupby(0).sum(combine_size=1)
731+
result = r.execute().fetch()
732+
pd.testing.assert_frame_equal(result, raw.groupby(0).sum())
733+
# test use shuffle
734+
assert len(r._fetch_infos()["memory_size"]) > 1
735+
736+
725737
def test_groupby_agg_str_cat(setup):
726738
agg_fun = lambda x: x.str.cat(sep="_", na_rep="NA")
727739

0 commit comments

Comments
 (0)