@@ -61,3 +61,166 @@ impl<S: CostModelStorageManager> CostModelImpl<S> {
6161 }
6262 }
6363}
64+
#[cfg(test)]
mod tests {
    //! Tests for the row-count estimate that `get_agg_row_cnt` produces for a group-by list.

    use std::collections::HashMap;

    use crate::{
        common::{predicates::constant_pred::ConstantType, types::TableId},
        cost_model::tests::{
            attr_ref, create_cost_model_mock_storage, empty_list, list, TestPerAttributeStats,
        },
        stats::{utilities::simple_map::SimpleMap, MostCommonValues, DEFAULT_NUM_DISTINCT},
        storage::Attribute,
        EstimatedStatistic,
    };

    /// Builds a non-nullable [`Attribute`] with the given name and type.
    ///
    /// Every attribute used by these tests is non-nullable, so the flag is fixed here to keep
    /// the fixtures short.
    fn non_null_attr(name: &str, typ: ConstantType) -> Attribute {
        Attribute {
            name: String::from(name),
            typ,
            nullable: false,
        }
    }

    /// Without any per-attribute stats, every grouped column is expected to contribute
    /// `DEFAULT_NUM_DISTINCT` to the estimate.
    #[tokio::test]
    async fn test_agg_no_stats() {
        let table_id = TableId(0);
        let attr_infos = HashMap::from([(
            table_id,
            HashMap::from([
                (0, non_null_attr("attr1", ConstantType::Int32)),
                (1, non_null_attr("attr2", ConstantType::Int64)),
            ]),
        )]);
        // The per-attribute stats list is intentionally left empty for this test.
        let cost_model =
            create_cost_model_mock_storage(vec![table_id], vec![], vec![None], attr_infos);

        // Group by empty list should return 1.
        let group_bys = empty_list();
        assert_eq!(
            cost_model.get_agg_row_cnt(group_bys).await.unwrap(),
            EstimatedStatistic(1)
        );

        // Group by single column should return the default value since there are no stats.
        let group_bys = list(vec![attr_ref(table_id, 0)]);
        assert_eq!(
            cost_model.get_agg_row_cnt(group_bys).await.unwrap(),
            EstimatedStatistic(DEFAULT_NUM_DISTINCT)
        );

        // Group by two columns should return the default value squared since there are no
        // stats.
        let group_bys = list(vec![attr_ref(table_id, 0), attr_ref(table_id, 1)]);
        assert_eq!(
            cost_model.get_agg_row_cnt(group_bys).await.unwrap(),
            EstimatedStatistic(DEFAULT_NUM_DISTINCT * DEFAULT_NUM_DISTINCT)
        );
    }

    /// With n-distinct stats registered for `attr1` and `attr2` only, the estimate is expected
    /// to be the product of the known n-distinct values, with `DEFAULT_NUM_DISTINCT` standing
    /// in for `attr3`, which has no stats.
    #[tokio::test]
    async fn test_agg_with_stats() {
        let table_id = TableId(0);
        let attr1_base_idx = 0;
        let attr2_base_idx = 1;
        let attr3_base_idx = 2;
        let attr_infos = HashMap::from([(
            table_id,
            HashMap::from([
                (attr1_base_idx, non_null_attr("attr1", ConstantType::Int32)),
                (attr2_base_idx, non_null_attr("attr2", ConstantType::Int64)),
                (attr3_base_idx, non_null_attr("attr3", ConstantType::Int64)),
            ]),
        )]);

        // Only the n-distinct counts are asserted on below; the most-common-values map is
        // left empty and the remaining constructor arguments are `None` and `0.0`.
        let attr1_ndistinct = 12;
        let attr2_ndistinct = 645;
        let attr1_stats = TestPerAttributeStats::new(
            MostCommonValues::SimpleFrequency(SimpleMap::default()),
            None,
            attr1_ndistinct,
            0.0,
        );
        let attr2_stats = TestPerAttributeStats::new(
            MostCommonValues::SimpleFrequency(SimpleMap::default()),
            None,
            attr2_ndistinct,
            0.0,
        );

        // `attr3` deliberately gets no stats entry so the default-value fallback is exercised.
        let cost_model = create_cost_model_mock_storage(
            vec![table_id],
            vec![HashMap::from([
                (attr1_base_idx, attr1_stats),
                (attr2_base_idx, attr2_stats),
            ])],
            vec![None],
            attr_infos,
        );

        // Group by empty list should return 1.
        let group_bys = empty_list();
        assert_eq!(
            cost_model.get_agg_row_cnt(group_bys).await.unwrap(),
            EstimatedStatistic(1)
        );

        // Group by single column should return the n-distinct of the column.
        let group_bys = list(vec![attr_ref(table_id, attr1_base_idx)]);
        assert_eq!(
            cost_model.get_agg_row_cnt(group_bys).await.unwrap(),
            EstimatedStatistic(attr1_ndistinct)
        );

        // Group by two columns should return the product of the n-distinct of the columns.
        let group_bys = list(vec![
            attr_ref(table_id, attr1_base_idx),
            attr_ref(table_id, attr2_base_idx),
        ]);
        assert_eq!(
            cost_model.get_agg_row_cnt(group_bys).await.unwrap(),
            EstimatedStatistic(attr1_ndistinct * attr2_ndistinct)
        );

        // Group by multiple columns should return the product of the n-distinct of the
        // columns. If one of the columns does not have stats, it should use the default value
        // instead.
        let group_bys = list(vec![
            attr_ref(table_id, attr1_base_idx),
            attr_ref(table_id, attr2_base_idx),
            attr_ref(table_id, attr3_base_idx),
        ]);
        assert_eq!(
            cost_model.get_agg_row_cnt(group_bys).await.unwrap(),
            EstimatedStatistic(attr1_ndistinct * attr2_ndistinct * DEFAULT_NUM_DISTINCT)
        );
    }
}
0 commit comments