Skip to content

Commit e249246

Browse files
authored
Hotfix/fix annealing schedule bug (#485)
## Changes * fix annealing schedule for corner-case (no quadratic terms) * add corresponding tests
2 parents 8b5c9c1 + 8ec5408 commit e249246

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

openjij/sampler/sa_sampler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -899,7 +899,9 @@ def geometric_ising_beta_schedule(
899899
dE_linear = (ising_interaction[:-1, -1].toarray() * random_spin[:-1])
900900

901901
dE_quad_abs = np.abs(dE_quad)
902-
rate_dE = np.max(np.abs(dE_linear[dE_quad_abs > THRESHOLD]) /(dE_quad_abs[dE_quad_abs > THRESHOLD].mean() + THRESHOLD))
902+
# Recalculate the maximum rate of change in energy if the absolute value of dE_quad is above the threshold
903+
if np.any(dE_quad_abs > THRESHOLD):
904+
rate_dE = np.max(np.abs(dE_linear[dE_quad_abs > THRESHOLD]) /(dE_quad_abs[dE_quad_abs > THRESHOLD].mean() + THRESHOLD))
903905

904906
dE = dE_quad
905907
dE_positive = dE[dE > THRESHOLD]

tests/test_sampler.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,54 @@ def test_error_handling(self):
278278
with self.assertRaises(TypeError):
279279
oj.SQASampler(trotter=10)
280280

281+
def test_simple_two_variable_cases(self):
282+
# Test case 1: (0,0): 1, (1,1): 1 with expected result [0,0]
283+
qubo1 = {(0, 0): 1, (1, 1): 1}
284+
285+
sampler = oj.SASampler()
286+
res = sampler.sample_qubo(qubo1, seed=1)
287+
self.assertEqual(len(res.states), 1)
288+
self.assertListEqual([0, 0], list(res.states[0]))
289+
290+
# Test case 2: (0,0): -1, (1,1): -1 with expected result [1,1]
291+
qubo2 = {(0, 0): -1, (1, 1): -1}
292+
293+
sampler = oj.SASampler()
294+
res = sampler.sample_qubo(qubo2, seed=1)
295+
self.assertEqual(len(res.states), 1)
296+
self.assertListEqual([1, 1], list(res.states[0]))
297+
298+
def test_simple_two_variable_cases_with_small_interaction(self):
299+
# Test case 1: (0,0): 1, (1,1): 1 with expected result [0,0]
300+
qubo1 = {(0, 0): 1, (1, 1): 1, (0, 1): 1e-7}
301+
302+
sampler = oj.SASampler()
303+
res = sampler.sample_qubo(qubo1, seed=1)
304+
self.assertEqual(len(res.states), 1)
305+
self.assertListEqual([0, 0], list(res.states[0]))
306+
307+
# Test case 2: (0,0): -1, (1,1): -1 with expected result [1,1]
308+
qubo2 = {(0, 0): -1, (1, 1): -1, (0, 1): 1e-7}
309+
310+
sampler = oj.SASampler()
311+
res = sampler.sample_qubo(qubo2, seed=1)
312+
self.assertEqual(len(res.states), 1)
313+
self.assertListEqual([1, 1], list(res.states[0]))
314+
315+
def test_qubo_with_null_interaction(self):
316+
# Test case with zero interaction
317+
qubo = {}
318+
319+
sampler = oj.SASampler()
320+
_ = sampler.sample_qubo(qubo, seed=1)
321+
322+
def test_qubo_with_zero_interaction(self):
323+
# Test case with zero interaction
324+
qubo = {(0, 0): 0, (1, 1): 0}
325+
326+
sampler = oj.SASampler()
327+
_ = sampler.sample_qubo(qubo, seed=1)
328+
281329

282330
if __name__ == '__main__':
283331
unittest.main()

0 commit comments

Comments
 (0)