Skip to content

Commit 0f4db36

Browse files
authored
Bug fix for ProxySPEX, removed external library label in tests (#484)
1 parent 41d2811 commit 0f4db36

File tree

3 files changed

+49
-4
lines changed

3 files changed

+49
-4
lines changed

src/shapiq/approximator/sparse/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,8 @@ def _refine(
482482
"""
483483
n = train_X.shape[1]
484484
four_items = list(four_dict.items())
485+
if len(four_items) <= self.n:
486+
return four_dict
485487
list_keys = [item[0] for item in four_items]
486488
four_coefs = np.array([item[1] for item in four_items])
487489

@@ -492,6 +494,8 @@ def _refine(
492494
four_coefs_for_energy[nfc_idx] = 0
493495
four_coefs_sq = four_coefs_for_energy**2
494496
tot_energy = np.sum(four_coefs_sq)
497+
if tot_energy == 0:
498+
return four_dict
495499
sorted_four_coefs_sq = np.sort(four_coefs_sq)[::-1]
496500
cumulative_energy_ratio = np.cumsum(sorted_four_coefs_sq / tot_energy)
497501
thresh_idx_95 = np.argmin(cumulative_energy_ratio < 0.95) + 1

tests/shapiq/tests_unit/tests_approximators/test_approximator_base_sparse.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,6 @@ def test_approximate_with_spex(
225225

226226

227227
@skip_if_no_lightgbm
228-
@pytest.mark.external_libraries
229228
def test_approximate_with_proxyspex():
230229
"""Tests the approximate method with the proxyspex decoder."""
231230
n = 10
@@ -263,3 +262,48 @@ def test_sparse_init_succeeds_lightgbm(monkeypatch):
263262

264263
with pytest.raises(ImportError, match="The 'lightgbm' package is required"):
265264
Sparse(n=10, index="SII", decoder_type="proxyspex")
265+
266+
267+
def test_refine_zero_total_energy():
268+
"""
269+
Tests that the _refine method handles the case when total energy is zero.
270+
271+
This test ensures that when all Fourier coefficients (excluding the baseline)
272+
are zero, the method returns the original dictionary without attempting to
273+
divide by zero.
274+
"""
275+
import numpy as np
276+
277+
n = 5
278+
# Use "soft" decoder which doesn't require lightgbm
279+
approximator = Sparse(n=n, index="FBII", max_order=2, decoder_type="soft", random_state=42)
280+
281+
# Create a four_dict where all non-baseline coefficients are zero
282+
four_dict = {
283+
(): 1.0, # baseline coefficient
284+
(0,): 0.0,
285+
(1,): 0.0,
286+
(0, 1): 0.0,
287+
(2,): 0.0,
288+
(1, 2): 0.0,
289+
}
290+
291+
# Create some dummy training data
292+
train_X = np.array(
293+
[
294+
[0, 0, 0, 0, 0],
295+
[1, 0, 0, 0, 0],
296+
[0, 1, 0, 0, 0],
297+
[1, 1, 0, 0, 0],
298+
[0, 0, 1, 0, 0],
299+
[1, 0, 1, 0, 0],
300+
]
301+
)
302+
train_y = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
303+
304+
# Call _refine - should not raise a ZeroDivisionError
305+
result = approximator._refine(four_dict, train_X, train_y)
306+
307+
# The result should be the same as the input when total energy is zero
308+
assert result == four_dict
309+
assert len(result) == len(four_dict)

tests/shapiq/tests_unit/tests_approximators/test_approximator_proxyspex.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313

1414
@skip_if_no_lightgbm
15-
@pytest.mark.external_libraries
1615
def test_initialization_defaults():
1716
"""Test that ProxySPEX initializes with correct defaults."""
1817
n = 10
@@ -36,7 +35,6 @@ def test_initialization_defaults():
3635
],
3736
)
3837
@skip_if_no_lightgbm
39-
@pytest.mark.external_libraries
4038
def test_initialization_custom(n, index, max_order, top_order):
4139
"""Test ProxySPEX initialization with custom parameters."""
4240
proxyspex = ProxySPEX(
@@ -61,7 +59,6 @@ def test_initialization_custom(n, index, max_order, top_order):
6159
],
6260
)
6361
@skip_if_no_lightgbm
64-
@pytest.mark.external_libraries
6562
def test_approximate(n, interactions, budget):
6663
"""Test ProxySPEX approximation functionality."""
6764

0 commit comments

Comments
 (0)