1717import argparse
1818import functools
1919import time
20- from typing import Callable
20+ from typing import Callable , List , Optional , Set , Union
2121
2222import mars
2323import mars .dataframe as md
2424
25- queries = None
25+ queries : Optional [ Union [ Set [ str ], List [ str ]]] = None
2626
2727
2828def load_lineitem (data_folder : str ) -> md .DataFrame :
@@ -158,7 +158,8 @@ def q01(lineitem: md.DataFrame):
158158 "L_ORDERKEY" : "count" ,
159159 }
160160 )
161- total = total .sort_values (["L_RETURNFLAG" , "L_LINESTATUS" ])
161+ # skip sort, Mars groupby enables sort
162+ # total = total.sort_values(["L_RETURNFLAG", "L_LINESTATUS"])
162163 print (total .execute ())
163164
164165
@@ -238,7 +239,9 @@ def q02(part, partsupp, supplier, nation, region):
238239 "P_MFGR" ,
239240 ],
240241 ]
241- min_values = merged_df .groupby ("P_PARTKEY" , as_index = False )["PS_SUPPLYCOST" ].min ()
242+ min_values = merged_df .groupby ("P_PARTKEY" , as_index = False , sort = False )[
243+ "PS_SUPPLYCOST"
244+ ].min ()
242245 min_values .columns = ["P_PARTKEY_CPY" , "MIN_SUPPLYCOST" ]
243246 merged_df = merged_df .merge (
244247 min_values ,
@@ -286,9 +289,9 @@ def q03(lineitem, orders, customer):
286289 jn2 = jn1 .merge (flineitem , left_on = "O_ORDERKEY" , right_on = "L_ORDERKEY" )
287290 jn2 ["TMP" ] = jn2 .L_EXTENDEDPRICE * (1 - jn2 .L_DISCOUNT )
288291 total = (
289- jn2 .groupby ([ "L_ORDERKEY" , "O_ORDERDATE" , "O_SHIPPRIORITY" ], as_index = False )[
290- "TMP"
291- ]
292+ jn2 .groupby (
293+ [ "L_ORDERKEY" , "O_ORDERDATE" , "O_SHIPPRIORITY" ], as_index = False , sort = False
294+ )[ "TMP" ]
292295 .sum ()
293296 .sort_values (["TMP" ], ascending = False )
294297 )
@@ -307,9 +310,9 @@ def q04(lineitem, orders):
307310 forders = orders [osel ]
308311 jn = forders [forders ["O_ORDERKEY" ].isin (flineitem ["L_ORDERKEY" ])]
309312 total = (
310- jn .groupby ("O_ORDERPRIORITY" , as_index = False )["O_ORDERKEY" ]
311- . count ()
312- .sort_values (["O_ORDERPRIORITY" ])
313+ jn .groupby ("O_ORDERPRIORITY" , as_index = False )["O_ORDERKEY" ]. count ()
314+ # skip sort when Mars enables sort in groupby
315+ # .sort_values(["O_ORDERPRIORITY"])
313316 )
314317 print (total .execute ())
315318
@@ -330,7 +333,7 @@ def q05(lineitem, orders, customer, nation, region, supplier):
330333 jn4 , left_on = ["S_SUPPKEY" , "S_NATIONKEY" ], right_on = ["L_SUPPKEY" , "N_NATIONKEY" ]
331334 )
332335 jn5 ["TMP" ] = jn5 .L_EXTENDEDPRICE * (1.0 - jn5 .L_DISCOUNT )
333- gb = jn5 .groupby ("N_NAME" , as_index = False )["TMP" ].sum ()
336+ gb = jn5 .groupby ("N_NAME" , as_index = False , sort = False )["TMP" ].sum ()
334337 total = gb .sort_values ("TMP" , ascending = False )
335338 print (total .execute ())
336339
@@ -436,9 +439,10 @@ def q07(lineitem, supplier, orders, customer, nation):
436439 total = total .groupby (["SUPP_NATION" , "CUST_NATION" , "L_YEAR" ], as_index = False ).agg (
437440 REVENUE = md .NamedAgg (column = "VOLUME" , aggfunc = "sum" )
438441 )
439- total = total .sort_values (
440- by = ["SUPP_NATION" , "CUST_NATION" , "L_YEAR" ], ascending = [True , True , True ]
441- )
442+ # skip sort when Mars groupby does sort already
443+ # total = total.sort_values(
444+ # by=["SUPP_NATION", "CUST_NATION", "L_YEAR"], ascending=[True, True, True]
445+ # )
442446 print (total .execute ())
443447
444448
@@ -520,7 +524,7 @@ def q09(lineitem, orders, part, nation, partsupp, supplier):
520524 (1 * jn5 .PS_SUPPLYCOST ) * jn5 .L_QUANTITY
521525 )
522526 jn5 ["O_YEAR" ] = jn5 .O_ORDERDATE .dt .year
523- gb = jn5 .groupby (["N_NAME" , "O_YEAR" ], as_index = False )["TMP" ].sum ()
527+ gb = jn5 .groupby (["N_NAME" , "O_YEAR" ], as_index = False , sort = False )["TMP" ].sum ()
524528 total = gb .sort_values (["N_NAME" , "O_YEAR" ], ascending = [True , False ])
525529 print (total .execute ())
526530
@@ -548,6 +552,7 @@ def q10(lineitem, orders, customer, nation):
548552 "C_COMMENT" ,
549553 ],
550554 as_index = False ,
555+ sort = False ,
551556 )["TMP" ].sum ()
552557 total = gb .sort_values ("TMP" , ascending = False )
553558 print (total .head (20 ).execute ())
@@ -571,7 +576,7 @@ def q11(partsupp, supplier, nation):
571576 )
572577 ps_supp_n_merge = ps_supp_n_merge .loc [:, ["PS_PARTKEY" , "TOTAL_COST" ]]
573578 sum_val = ps_supp_n_merge ["TOTAL_COST" ].sum () * 0.0001
574- total = ps_supp_n_merge .groupby (["PS_PARTKEY" ], as_index = False ).agg (
579+ total = ps_supp_n_merge .groupby (["PS_PARTKEY" ], as_index = False , sort = False ).agg (
575580 VALUE = md .NamedAgg (column = "TOTAL_COST" , aggfunc = "sum" )
576581 )
577582 total = total [total ["VALUE" ] > sum_val ]
@@ -603,7 +608,8 @@ def g2(x):
603608
604609 total = jn .groupby ("L_SHIPMODE" , as_index = False )["O_ORDERPRIORITY" ].agg ((g1 , g2 ))
605610 total = total .reset_index () # reset index to keep consistency with pandas
606- total = total .sort_values ("L_SHIPMODE" )
611+ # skip sort when groupby does sort already
612+ # total = total.sort_values("L_SHIPMODE")
607613 print (total .execute ())
608614
609615
@@ -618,10 +624,10 @@ def q13(customer, orders):
618624 orders_filtered , left_on = "C_CUSTKEY" , right_on = "O_CUSTKEY" , how = "left"
619625 )
620626 c_o_merged = c_o_merged .loc [:, ["C_CUSTKEY" , "O_ORDERKEY" ]]
621- count_df = c_o_merged .groupby (["C_CUSTKEY" ], as_index = False ).agg (
627+ count_df = c_o_merged .groupby (["C_CUSTKEY" ], as_index = False , sort = False ).agg (
622628 C_COUNT = md .NamedAgg (column = "O_ORDERKEY" , aggfunc = "count" )
623629 )
624- total = count_df .groupby (["C_COUNT" ], as_index = False ).size ()
630+ total = count_df .groupby (["C_COUNT" ], as_index = False , sort = False ).size ()
625631 total .columns = ["C_COUNT" , "CUSTDIST" ]
626632 total = total .sort_values (by = ["CUSTDIST" , "C_COUNT" ], ascending = [False , False ])
627633 print (total .execute ())
@@ -660,7 +666,7 @@ def q15(lineitem, supplier):
660666 )
661667 lineitem_filtered = lineitem_filtered .loc [:, ["L_SUPPKEY" , "REVENUE_PARTS" ]]
662668 revenue_table = (
663- lineitem_filtered .groupby ("L_SUPPKEY" , as_index = False )
669+ lineitem_filtered .groupby ("L_SUPPKEY" , as_index = False , sort = False )
664670 .agg (TOTAL_REVENUE = md .NamedAgg (column = "REVENUE_PARTS" , aggfunc = "sum" ))
665671 .rename (columns = {"L_SUPPKEY" : "SUPPLIER_NO" })
666672 )
@@ -699,7 +705,7 @@ def q16(part, partsupp, supplier):
699705 )
700706 total = total [total ["S_SUPPKEY" ].isna ()]
701707 total = total .loc [:, ["P_BRAND" , "P_TYPE" , "P_SIZE" , "PS_SUPPKEY" ]]
702- total = total .groupby (["P_BRAND" , "P_TYPE" , "P_SIZE" ], as_index = False )[
708+ total = total .groupby (["P_BRAND" , "P_TYPE" , "P_SIZE" ], as_index = False , sort = False )[
703709 "PS_SUPPKEY"
704710 ].nunique ()
705711 total .columns = ["P_BRAND" , "P_TYPE" , "P_SIZE" , "SUPPLIER_CNT" ]
@@ -722,9 +728,9 @@ def q17(lineitem, part):
722728 :, ["L_QUANTITY" , "L_EXTENDEDPRICE" , "P_PARTKEY" ]
723729 ]
724730 lineitem_filtered = lineitem .loc [:, ["L_PARTKEY" , "L_QUANTITY" ]]
725- lineitem_avg = lineitem_filtered .groupby ([ "L_PARTKEY" ], as_index = False ). agg (
726- avg = md . NamedAgg ( column = "L_QUANTITY" , aggfunc = "mean" )
727- )
731+ lineitem_avg = lineitem_filtered .groupby (
732+ [ "L_PARTKEY" ], as_index = False , sort = False
733+ ). agg ( avg = md . NamedAgg ( column = "L_QUANTITY" , aggfunc = "mean" ))
728734 lineitem_avg ["avg" ] = 0.2 * lineitem_avg ["avg" ]
729735 lineitem_avg = lineitem_avg .loc [:, ["L_PARTKEY" , "avg" ]]
730736 total = line_part_merge .merge (
@@ -737,13 +743,14 @@ def q17(lineitem, part):
737743
738744@tpc_query
739745def q18 (lineitem , orders , customer ):
740- gb1 = lineitem .groupby ("L_ORDERKEY" , as_index = False )["L_QUANTITY" ].sum ()
746+ gb1 = lineitem .groupby ("L_ORDERKEY" , as_index = False , sort = False )["L_QUANTITY" ].sum ()
741747 fgb1 = gb1 [gb1 .L_QUANTITY > 300 ]
742748 jn1 = fgb1 .merge (orders , left_on = "L_ORDERKEY" , right_on = "O_ORDERKEY" )
743749 jn2 = jn1 .merge (customer , left_on = "O_CUSTKEY" , right_on = "C_CUSTKEY" )
744750 gb2 = jn2 .groupby (
745751 ["C_NAME" , "C_CUSTKEY" , "O_ORDERKEY" , "O_ORDERDATE" , "O_TOTALPRICE" ],
746752 as_index = False ,
753+ sort = False ,
747754 )["L_QUANTITY" ].sum ()
748755 total = gb2 .sort_values (["O_TOTALPRICE" , "O_ORDERDATE" ], ascending = [False , True ])
749756 print (total .head (100 ).execute ())
@@ -865,9 +872,9 @@ def q20(lineitem, part, nation, partsupp, supplier):
865872 left_on = ["PS_PARTKEY" , "PS_SUPPKEY" ],
866873 right_on = ["L_PARTKEY" , "L_SUPPKEY" ],
867874 )
868- gb = jn2 .groupby ([ "PS_PARTKEY" , "PS_SUPPKEY" , "PS_AVAILQTY" ], as_index = False )[
869- "L_QUANTITY"
870- ].sum ()
875+ gb = jn2 .groupby (
876+ [ "PS_PARTKEY" , "PS_SUPPKEY" , "PS_AVAILQTY" ], as_index = False , sort = False
877+ )[ "L_QUANTITY" ].sum ()
871878 gbsel = gb .PS_AVAILQTY > (0.5 * gb .L_QUANTITY )
872879 fgb = gb [gbsel ]
873880 jn3 = fgb .merge (supplier , left_on = "PS_SUPPKEY" , right_on = "S_SUPPKEY" )
@@ -886,7 +893,7 @@ def q21(lineitem, orders, supplier, nation):
886893 # Keep all rows that have another row in lineitem with the same orderkey and different suppkey
887894 lineitem_orderkeys = (
888895 lineitem_filtered .loc [:, ["L_ORDERKEY" , "L_SUPPKEY" ]]
889- .groupby ("L_ORDERKEY" , as_index = False )["L_SUPPKEY" ]
896+ .groupby ("L_ORDERKEY" , as_index = False , sort = False )["L_SUPPKEY" ]
890897 .nunique ()
891898 )
892899 lineitem_orderkeys .columns = ["L_ORDERKEY" , "nunique_col" ]
@@ -905,9 +912,9 @@ def q21(lineitem, orders, supplier, nation):
905912 )
906913
907914 # Not Exists: Check the exists condition isn't still satisfied on the output.
908- lineitem_orderkeys = lineitem_filtered .groupby ("L_ORDERKEY" , as_index = False )[
909- "L_SUPPKEY"
910- ].nunique ()
915+ lineitem_orderkeys = lineitem_filtered .groupby (
916+ "L_ORDERKEY" , as_index = False , sort = False
917+ )[ "L_SUPPKEY" ].nunique ()
911918 lineitem_orderkeys .columns = ["L_ORDERKEY" , "nunique_col" ]
912919 lineitem_orderkeys = lineitem_orderkeys [lineitem_orderkeys ["nunique_col" ] == 1 ]
913920 lineitem_orderkeys = lineitem_orderkeys .loc [:, ["L_ORDERKEY" ]]
@@ -936,7 +943,7 @@ def q21(lineitem, orders, supplier, nation):
936943 nation_filtered , left_on = "S_NATIONKEY" , right_on = "N_NATIONKEY" , how = "inner"
937944 )
938945 total = total .loc [:, ["S_NAME" ]]
939- total = total .groupby ("S_NAME" , as_index = False ).size ()
946+ total = total .groupby ("S_NAME" , as_index = False , sort = False ).size ()
940947 total .columns = ["S_NAME" , "NUMWAIT" ]
941948 total = total .sort_values (by = ["NUMWAIT" , "S_NAME" ], ascending = [False , True ])
942949 print (total .execute ())
@@ -966,17 +973,21 @@ def q22(customer, orders):
966973 customer_filtered , on = "C_CUSTKEY" , how = "inner"
967974 )
968975 customer_selected = customer_selected .loc [:, ["CNTRYCODE" , "C_ACCTBAL" ]]
969- agg1 = customer_selected .groupby (["CNTRYCODE" ], as_index = False ).size ()
976+ agg1 = customer_selected .groupby (["CNTRYCODE" ], as_index = False , sort = False ).size ()
970977 agg1 .columns = ["CNTRYCODE" , "NUMCUST" ]
971- agg2 = customer_selected .groupby (["CNTRYCODE" ], as_index = False ).agg (
978+ agg2 = customer_selected .groupby (["CNTRYCODE" ], as_index = False , sort = False ).agg (
972979 TOTACCTBAL = md .NamedAgg (column = "C_ACCTBAL" , aggfunc = "sum" )
973980 )
974981 total = agg1 .merge (agg2 , on = "CNTRYCODE" , how = "inner" )
975982 total = total .sort_values (by = ["CNTRYCODE" ], ascending = [True ])
976983 print (total .execute ())
977984
978985
979- def run_queries (data_folder : str ):
986+ def run_queries (data_folder : str , select : List [str ] = None ):
987+ if select :
988+ global queries
989+ queries = select
990+
980991 # Load the data
981992 t1 = time .time ()
982993 lineitem = load_lineitem (data_folder )
0 commit comments