Skip to content

Commit 469fda9

Browse files
author
Alexander März
committed
Updates
1 parent e477adc commit 469fda9

File tree

3 files changed

+23
-79
lines changed

3 files changed

+23
-79
lines changed

tests/test_distribution_utils/test_compute_gradients_and_hessians.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,7 @@ class TestClass(BaseTestClass):
88
def test_compute_gradients_and_hessians(self, dist_class, loss_fn, stabilization):
99
# Create data for testing
1010
params, target, weights, _ = gen_test_data(dist_class, weights=True)
11-
if dist_class.dist.univariate:
12-
target = torch.tensor(target)
13-
else:
14-
target = torch.tensor(target)[:, :dist_class.dist.n_targets]
11+
target = torch.tensor(target)
1512
start_values = np.array([0.5 for _ in range(dist_class.dist.n_dist_param)])
1613

1714
# Set the loss function for testing
@@ -44,10 +41,7 @@ def test_compute_gradients_and_hessians(self, dist_class, loss_fn, stabilization
4441
def test_compute_gradients_and_hessians_crps(self, dist_class_crps, stabilization):
4542
# Create data for testing
4643
params, target, weights, _ = gen_test_data(dist_class_crps, weights=True)
47-
if dist_class_crps.dist.univariate:
48-
target = torch.tensor(target)
49-
else:
50-
target = torch.tensor(target)[:, :dist_class_crps.dist.n_targets]
44+
target = torch.tensor(target)
5145
start_values = np.array([0.5 for _ in range(dist_class_crps.dist.n_dist_param)])
5246

5347
# Set the loss function for testing
@@ -81,10 +75,7 @@ def test_compute_gradients_and_hessians_nans(self, dist_class, loss_fn, stabiliz
8175
# Create data for testing
8276
params, target, weights, _ = gen_test_data(dist_class, weights=True)
8377
params[0, 0] = np.nan
84-
if dist_class.dist.univariate:
85-
target = torch.tensor(target)
86-
else:
87-
target = torch.tensor(target)[:, :dist_class.dist.n_targets]
78+
target = torch.tensor(target)
8879
start_values = np.array([0.5 for _ in range(dist_class.dist.n_dist_param)])
8980

9081
# Set the loss function for testing

tests/test_distribution_utils/test_loss_fn_start_values.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,7 @@ def test_loss_fn_start_values(self, dist_class, loss_fn):
1010
torch.tensor(0.5, dtype=torch.float64).reshape(-1, 1).requires_grad_(True) for _ in
1111
range(dist_class.dist.n_dist_param)
1212
]
13-
if dist_class.dist.univariate:
14-
target = torch.tensor(target)
15-
else:
16-
target = torch.tensor(target)[:, :dist_class.dist.n_targets]
13+
target = torch.tensor(target)
1714

1815
# Set the loss function for testing
1916
dist_class.dist.loss_fn = loss_fn

tests/utils.py

Lines changed: 19 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -30,41 +30,20 @@ def gen_test_data(dist_class, weights: bool = False):
3030
dmatrix (lgb.Dataset):
3131
DMatrix.
3232
"""
33-
if dist_class.dist.univariate:
34-
np.random.seed(123)
35-
predt = np.random.rand(dist_class.dist.n_dist_param * 4).reshape(-1, dist_class.dist.n_dist_param)
36-
labels = np.array([0.2, 0.4, 0.6, 0.8]).reshape(-1, 1)
37-
if weights:
38-
weights = np.ones_like(labels)
39-
dmatrix = lgb.Dataset(predt, label=labels, weight=weights)
40-
dist_class.set_init_score(dmatrix)
41-
42-
return predt, labels, weights, dmatrix
43-
else:
44-
dmatrix = lgb.Dataset(predt, label=labels)
45-
dist_class.set_init_score(dmatrix)
46-
47-
return predt, labels, dmatrix
33+
np.random.seed(123)
34+
predt = np.random.rand(dist_class.dist.n_dist_param * 4).reshape(-1, dist_class.dist.n_dist_param)
35+
labels = np.array([0.2, 0.4, 0.6, 0.8]).reshape(-1, 1)
36+
if weights:
37+
weights = np.ones_like(labels)
38+
dmatrix = lgb.Dataset(predt, label=labels, weight=weights)
39+
dist_class.set_init_score(dmatrix)
40+
41+
return predt, labels, weights, dmatrix
4842
else:
49-
np.random.seed(123)
50-
predt = np.random.rand(dist_class.dist.n_dist_param * 4).reshape(-1, dist_class.dist.n_dist_param)
51-
labels = np.arange(0.1, 0.9, 0.1)
52-
labels = dist_class.dist.target_append(
53-
labels,
54-
dist_class.dist.n_targets,
55-
dist_class.dist.n_dist_param
56-
)
57-
if weights:
58-
weights = np.ones_like(labels[:, 0], dtype=labels.dtype).reshape(-1, 1)
59-
dmatrix = lgb.Dataset(predt, label=labels, weight=weights)
60-
dist_class.set_init_score(dmatrix)
61-
62-
return predt, labels, weights, dmatrix
63-
else:
64-
dmatrix = lgb.Dataset(predt, label=labels)
65-
dist_class.set_init_score(dmatrix)
43+
dmatrix = lgb.Dataset(predt, label=labels)
44+
dist_class.set_init_score(dmatrix)
6645

67-
return predt, labels, dmatrix
46+
return predt, labels, dmatrix
6847

6948

7049
def get_distribution_classes(univariate: bool = True,
@@ -128,18 +107,6 @@ def get_distribution_classes(univariate: bool = True,
128107
if distribution_class().univariate and distribution_class().discrete:
129108
univar_discrete_distns.append(distribution_class)
130109

131-
# Extract all multivariate distributions
132-
multivar_distns = []
133-
for distribution_name in distns:
134-
# Import the module dynamically
135-
module = importlib.import_module(f"lightgbmlss.distributions.{distribution_name}")
136-
137-
# Get the class dynamically from the module
138-
distribution_class = getattr(module, distribution_name)
139-
140-
if not distribution_class().univariate:
141-
multivar_distns.append(distribution_class)
142-
143110
# Extract distributions only that have a rsample method
144111
rsample_distns = []
145112
for distribution_name in distns:
@@ -178,9 +145,6 @@ def get_distribution_classes(univariate: bool = True,
178145
else:
179146
return univar_cont_distns
180147

181-
elif not univariate and not flow and not expectile:
182-
return multivar_distns
183-
184148
elif flow:
185149
distribution_name = "SplineFlow"
186150
module = importlib.import_module(f"lightgbmlss.distributions.{distribution_name}")
@@ -207,10 +171,6 @@ def univariate_cont_dist(self, request):
207171
def univariate_discrete_dist(self, request):
208172
return request.param
209173

210-
@pytest.fixture(params=get_distribution_classes(univariate=False))
211-
def multivariate_dist(self, request):
212-
return request.param
213-
214174
@pytest.fixture(params=get_distribution_classes(flow=True))
215175
def flow_dist(self, request):
216176
return request.param
@@ -219,24 +179,20 @@ def flow_dist(self, request):
219179
def expectile_dist(self, request):
220180
return request.param
221181

222-
@pytest.fixture(params=
223-
get_distribution_classes() +
224-
get_distribution_classes(discrete=True) +
225-
get_distribution_classes(expectile=True) +
226-
get_distribution_classes(flow=True) +
227-
get_distribution_classes(univariate=False)
228-
)
182+
@pytest.fixture(
183+
params=get_distribution_classes() +
184+
get_distribution_classes(discrete=True) +
185+
get_distribution_classes(expectile=True) +
186+
get_distribution_classes(flow=True) +
187+
get_distribution_classes(univariate=False)
188+
)
229189
def dist_class(self, request):
230190
return LightGBMLSS(request.param())
231191

232192
@pytest.fixture(params=get_distribution_classes(flow=True))
233193
def flow_class(self, request):
234194
return LightGBMLSS(request.param())
235195

236-
@pytest.fixture(params=get_distribution_classes(univariate=False))
237-
def multivariate_class(self, request):
238-
return LightGBMLSS(request.param())
239-
240196
@pytest.fixture(params=get_distribution_classes(rsample=True))
241197
def dist_class_crps(self, request):
242198
return LightGBMLSS(request.param())

0 commit comments

Comments
 (0)