@@ -30,41 +30,20 @@ def gen_test_data(dist_class, weights: bool = False):
3030 dmatrix (lgb.Dataset):
3131 DMatrix.
3232 """
33- if dist_class .dist .univariate :
34- np .random .seed (123 )
35- predt = np .random .rand (dist_class .dist .n_dist_param * 4 ).reshape (- 1 , dist_class .dist .n_dist_param )
36- labels = np .array ([0.2 , 0.4 , 0.6 , 0.8 ]).reshape (- 1 , 1 )
37- if weights :
38- weights = np .ones_like (labels )
39- dmatrix = lgb .Dataset (predt , label = labels , weight = weights )
40- dist_class .set_init_score (dmatrix )
41-
42- return predt , labels , weights , dmatrix
43- else :
44- dmatrix = lgb .Dataset (predt , label = labels )
45- dist_class .set_init_score (dmatrix )
46-
47- return predt , labels , dmatrix
33+ np .random .seed (123 )
34+ predt = np .random .rand (dist_class .dist .n_dist_param * 4 ).reshape (- 1 , dist_class .dist .n_dist_param )
35+ labels = np .array ([0.2 , 0.4 , 0.6 , 0.8 ]).reshape (- 1 , 1 )
36+ if weights :
37+ weights = np .ones_like (labels )
38+ dmatrix = lgb .Dataset (predt , label = labels , weight = weights )
39+ dist_class .set_init_score (dmatrix )
40+
41+ return predt , labels , weights , dmatrix
4842 else :
49- np .random .seed (123 )
50- predt = np .random .rand (dist_class .dist .n_dist_param * 4 ).reshape (- 1 , dist_class .dist .n_dist_param )
51- labels = np .arange (0.1 , 0.9 , 0.1 )
52- labels = dist_class .dist .target_append (
53- labels ,
54- dist_class .dist .n_targets ,
55- dist_class .dist .n_dist_param
56- )
57- if weights :
58- weights = np .ones_like (labels [:, 0 ], dtype = labels .dtype ).reshape (- 1 , 1 )
59- dmatrix = lgb .Dataset (predt , label = labels , weight = weights )
60- dist_class .set_init_score (dmatrix )
61-
62- return predt , labels , weights , dmatrix
63- else :
64- dmatrix = lgb .Dataset (predt , label = labels )
65- dist_class .set_init_score (dmatrix )
43+ dmatrix = lgb .Dataset (predt , label = labels )
44+ dist_class .set_init_score (dmatrix )
6645
67- return predt , labels , dmatrix
46+ return predt , labels , dmatrix
6847
6948
7049def get_distribution_classes (univariate : bool = True ,
@@ -128,18 +107,6 @@ def get_distribution_classes(univariate: bool = True,
128107 if distribution_class ().univariate and distribution_class ().discrete :
129108 univar_discrete_distns .append (distribution_class )
130109
131- # Extract all multivariate distributions
132- multivar_distns = []
133- for distribution_name in distns :
134- # Import the module dynamically
135- module = importlib .import_module (f"lightgbmlss.distributions.{ distribution_name } " )
136-
137- # Get the class dynamically from the module
138- distribution_class = getattr (module , distribution_name )
139-
140- if not distribution_class ().univariate :
141- multivar_distns .append (distribution_class )
142-
143110 # Extract distributions only that have a rsample method
144111 rsample_distns = []
145112 for distribution_name in distns :
@@ -178,9 +145,6 @@ def get_distribution_classes(univariate: bool = True,
178145 else :
179146 return univar_cont_distns
180147
181- elif not univariate and not flow and not expectile :
182- return multivar_distns
183-
184148 elif flow :
185149 distribution_name = "SplineFlow"
186150 module = importlib .import_module (f"lightgbmlss.distributions.{ distribution_name } " )
@@ -207,10 +171,6 @@ def univariate_cont_dist(self, request):
207171 def univariate_discrete_dist (self , request ):
208172 return request .param
209173
210- @pytest .fixture (params = get_distribution_classes (univariate = False ))
211- def multivariate_dist (self , request ):
212- return request .param
213-
214174 @pytest .fixture (params = get_distribution_classes (flow = True ))
215175 def flow_dist (self , request ):
216176 return request .param
@@ -219,24 +179,20 @@ def flow_dist(self, request):
219179 def expectile_dist (self , request ):
220180 return request .param
221181
222- @pytest .fixture (params =
223- get_distribution_classes () +
224- get_distribution_classes (discrete = True ) +
225- get_distribution_classes (expectile = True ) +
226- get_distribution_classes (flow = True ) +
227- get_distribution_classes (univariate = False )
228- )
182+ @pytest .fixture (
183+ params = get_distribution_classes () +
184+ get_distribution_classes (discrete = True ) +
185+ get_distribution_classes (expectile = True ) +
186+ get_distribution_classes (flow = True ) +
187+ get_distribution_classes (univariate = False )
188+ )
229189 def dist_class (self , request ):
230190 return LightGBMLSS (request .param ())
231191
232192 @pytest .fixture (params = get_distribution_classes (flow = True ))
233193 def flow_class (self , request ):
234194 return LightGBMLSS (request .param ())
235195
236- @pytest .fixture (params = get_distribution_classes (univariate = False ))
237- def multivariate_class (self , request ):
238- return LightGBMLSS (request .param ())
239-
240196 @pytest .fixture (params = get_distribution_classes (rsample = True ))
241197 def dist_class_crps (self , request ):
242198 return LightGBMLSS (request .param ())
0 commit comments