Skip to content

Commit 70b61d8

Browse files
committed
[Tests] Add tests for RL utils and reward observers
1 parent 5fe4e91 commit 70b61d8

File tree

4 files changed

+287
-3
lines changed

4 files changed

+287
-3
lines changed

job_shop_lib/reinforcement_learning/_reward_observers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,10 @@ class RewardWithPenalties(RewardObserver):
125125
The following functions (along with ``functools.partial``) can be
126126
used to create penalty functions:
127127
128-
- :class:`job_shop_lib.metaheuristics.penalty_for_deadlines`
129-
- :class:`job_shop_lib.metaheuristics.penalty_for_due_dates`
128+
- :class:`~job_shop_lib.reinforcement_learning.get_deadline_violation_penalty`
129+
- :class:`~job_shop_lib.reinforcement_learning.get_due_date_violation_penalty`
130130
131-
"""
131+
""" # noqa: E501
132132

133133
def __init__(
134134
self,

tests/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,3 +334,17 @@ def ft06_instance():
334334
@pytest.fixture
335335
def seeded_rng() -> random.Random:
336336
return random.Random(42)
337+
338+
339+
@pytest.fixture
def single_machine_instance() -> JobShopInstance:
    """Two single-operation jobs that both require machine 0."""
    competing_jobs = [[Operation(0, 2)], [Operation(0, 3)]]
    return JobShopInstance(competing_jobs, name="SingleMachine")
344+
345+
346+
@pytest.fixture
def two_machines_instance() -> JobShopInstance:
    """Two single-operation jobs on distinct machines (0 and 1)."""
    independent_jobs = [[Operation(0, 5)], [Operation(1, 3)]]
    return JobShopInstance(independent_jobs, name="TwoMachines")
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# pylint: disable=missing-function-docstring, redefined-outer-name
2+
import functools
3+
import pytest
4+
5+
from job_shop_lib import JobShopInstance, Operation
6+
from job_shop_lib.dispatching import Dispatcher
7+
from job_shop_lib.exceptions import ValidationError
8+
from job_shop_lib.reinforcement_learning import (
9+
MakespanReward,
10+
IdleTimeReward,
11+
RewardWithPenalties,
12+
get_due_date_violation_penalty,
13+
get_deadline_violation_penalty,
14+
)
15+
16+
17+
def test_makespan_reward_basic(single_machine_instance: JobShopInstance):
    """Each reward is the negative makespan increase of that dispatch."""
    dispatcher = Dispatcher(single_machine_instance)
    observer = MakespanReward(dispatcher)

    first_op = single_machine_instance.jobs[0][0]
    second_op = single_machine_instance.jobs[1][0]

    # First job runs on machine 0 during [0, 2) -> makespan grows by 2.
    dispatcher.dispatch(first_op, 0)
    assert observer.rewards[-1] == -2

    # Second job queues behind it, running [2, 5) -> makespan grows by 3.
    dispatcher.dispatch(second_op, 0)
    assert observer.rewards[-1] == -3

    # Accumulated rewards telescope to the negated final makespan.
    assert sum(observer.rewards) == -dispatcher.schedule.makespan() == -5
34+
35+
36+
def test_makespan_reward_zero_when_no_increase(
    two_machines_instance: JobShopInstance,
):
    """A dispatch that does not extend the makespan yields a zero reward."""
    dispatcher = Dispatcher(two_machines_instance)
    observer = MakespanReward(dispatcher)

    longer_op = two_machines_instance.jobs[0][0]
    shorter_op = two_machines_instance.jobs[1][0]

    # The longer operation sets the makespan to 5.
    dispatcher.dispatch(longer_op, 0)
    assert observer.rewards[-1] == -5

    # The shorter one ends at 3 < 5 on its own machine: makespan unchanged.
    dispatcher.dispatch(shorter_op, 1)
    assert observer.rewards[-1] == 0
53+
54+
55+
def test_idle_time_reward_computation():
    """Idle-time reward equals the negative gap opened on a machine.

    Layout: job 1 finishes on machine 0 at t=1, but job 0's second
    operation cannot start on machine 0 until its predecessor on
    machine 1 ends at t=5 — leaving machine 0 idle during [1, 5).
    """
    jobs = [
        [Operation(1, 5), Operation(0, 1)],  # job 0
        [Operation(0, 1), Operation(1, 1)],  # job 1
    ]
    instance = JobShopInstance(jobs, name="IdleTimeExample")
    dispatcher = Dispatcher(instance)
    observer = IdleTimeReward(dispatcher)

    # The first operation on each machine starts at t=0: no idle time.
    dispatcher.dispatch(instance.jobs[1][0], 0)
    assert observer.rewards[-1] == 0
    dispatcher.dispatch(instance.jobs[0][0], 1)
    assert observer.rewards[-1] == 0

    # Job 1's second op starts exactly when machine 1 frees up (t=5).
    dispatcher.dispatch(instance.jobs[1][1], 1)
    assert observer.rewards[-1] == 0

    # Job 0's second op starts at t=5 on machine 0, idle since t=1.
    dispatcher.dispatch(instance.jobs[0][1], 0)
    assert observer.rewards[-1] == -4
82+
83+
84+
def test_reward_with_penalties_due_date():
    """A due-date violation subtracts the penalty from the base reward."""
    jobs = [
        [Operation(0, 1)],
        # Starts at t=1, ends at t=11 -> past its due date of 5.
        [Operation(0, 10, due_date=5)],
    ]
    instance = JobShopInstance(jobs, name="DueDatePenalty")
    dispatcher = Dispatcher(instance)

    base_observer = MakespanReward(dispatcher)
    wrapped_observer = RewardWithPenalties(
        dispatcher,
        base_reward_observer=base_observer,
        penalty_function=functools.partial(
            get_due_date_violation_penalty, due_date_penalty_factor=7
        ),
    )

    # On-time operation: the wrapper mirrors the base reward exactly.
    dispatcher.dispatch(instance.jobs[0][0], 0)
    assert base_observer.rewards[-1] == -1
    assert wrapped_observer.rewards[-1] == -1

    # Late operation: base reward minus the penalty factor of 7.
    dispatcher.dispatch(instance.jobs[1][0], 0)
    assert base_observer.rewards[-1] == -10
    assert wrapped_observer.rewards[-1] == -10 - 7
114+
115+
116+
def test_reward_with_penalties_deadline():
    """A deadline violation is charged at ``deadline_penalty_factor``."""
    jobs = [
        [Operation(0, 1)],
        # Ends at t=11, past its hard deadline of 5.
        [Operation(0, 10, deadline=5)],
    ]
    instance = JobShopInstance(jobs, name="DeadlinePenalty")
    dispatcher = Dispatcher(instance)

    base_observer = MakespanReward(dispatcher)
    wrapped_observer = RewardWithPenalties(
        dispatcher,
        base_reward_observer=base_observer,
        penalty_function=functools.partial(
            get_deadline_violation_penalty, deadline_penalty_factor=13
        ),
    )

    # Dispatch both jobs on machine 0, in job order.
    for job in instance.jobs:
        dispatcher.dispatch(job[0], 0)

    # Base reward (-10) minus the deadline penalty (13).
    assert wrapped_observer.rewards[-1] == -10 - 13
137+
138+
139+
def test_reward_with_penalties_requires_same_dispatcher():
    """Wrapping a base observer from a different dispatcher must fail."""
    instance = JobShopInstance([[Operation(0, 1)]])
    dispatcher_a = Dispatcher(instance)
    dispatcher_b = Dispatcher(instance)

    # Observer bound to dispatcher_a, wrapper built on dispatcher_b.
    foreign_base = MakespanReward(dispatcher_a)
    with pytest.raises(ValidationError):
        RewardWithPenalties(
            dispatcher_b,
            base_reward_observer=foreign_base,
            penalty_function=lambda op, d: 0.0,
        )
149+
150+
151+
def test_reward_with_penalties_unsubscribes_base():
    """Only the wrapper stays subscribed; reset clears both observers."""
    instance = JobShopInstance([[Operation(0, 1)], [Operation(0, 1)]])
    dispatcher = Dispatcher(instance)

    base_observer = MakespanReward(dispatcher)
    assert base_observer in dispatcher.subscribers

    wrapper = RewardWithPenalties(
        dispatcher,
        base_reward_observer=base_observer,
        penalty_function=lambda op, d: 0.0,
    )

    # Construction swaps the base out of the subscriber list.
    assert base_observer not in dispatcher.subscribers
    assert wrapper in dispatcher.subscribers

    # Resetting the wrapper empties both reward histories.
    wrapper.reset()
    assert not wrapper.rewards
    assert not base_observer.rewards
171+
172+
173+
def test_reward_observers_reset():
    """``reset`` discards all rewards collected by either observer type."""
    instance = JobShopInstance([[Operation(0, 1)], [Operation(0, 1)]])
    dispatcher = Dispatcher(instance)

    makespan_observer = MakespanReward(dispatcher)
    idle_observer = IdleTimeReward(dispatcher)

    # Dispatch both jobs so each observer records something.
    for job in instance.jobs:
        dispatcher.dispatch(job[0], 0)
    assert makespan_observer.rewards
    assert idle_observer.rewards

    # After reset both histories must be empty.
    for observer in (makespan_observer, idle_observer):
        observer.reset()
        assert not observer.rewards

tests/reinforcement_learning/test_utils.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99
add_padding,
1010
create_edge_type_dict,
1111
map_values,
12+
get_optimal_actions,
1213
get_deadline_violation_penalty,
1314
get_due_date_violation_penalty,
1415
)
16+
from job_shop_lib.dispatching import OptimalOperationsObserver
17+
from job_shop_lib.dispatching.rules import DispatchingRuleSolver
1518

1619

1720
def test_add_padding_int_array():
@@ -401,5 +404,81 @@ def test_due_date_penalty_custom_factor():
401404
)
402405

403406

407+
# ---------------- get_optimal_actions tests ---------------- #
408+
409+
410+
def test_get_optimal_actions_initial_and_after_step(
    example_job_shop_instance: JobShopInstance,
):
    """Mapping flags optimal actions with 1, both initially and mid-run."""

    def _assert_mapping_matches_observer(observer, dispatcher):
        # One (operation_id, machine_id, job_id) tuple per available op.
        actions = [
            (op.operation_id, op.machine_id, op.job_id)
            for op in dispatcher.available_operations()
        ]
        mapping = get_optimal_actions(observer, actions)
        optimal_ids = {
            (op.operation_id, op.machine_id, op.job_id)
            for op in observer.optimal_available
        }
        # 1 for actions in the optimal set, 0 for everything else.
        for action in actions:
            assert mapping[action] == int(action in optimal_ids)

    # A heuristic schedule serves as the reference solution.
    reference_schedule = DispatchingRuleSolver().solve(
        example_job_shop_instance
    )
    dispatcher = Dispatcher(example_job_shop_instance)
    optimal_obs = OptimalOperationsObserver(dispatcher, reference_schedule)

    _assert_mapping_matches_observer(optimal_obs, dispatcher)

    # Dispatch one optimal operation; the mapping must stay consistent.
    dispatcher.dispatch(next(iter(optimal_obs.optimal_available)))
    _assert_mapping_matches_observer(optimal_obs, dispatcher)
453+
454+
455+
def test_get_optimal_actions_marks_non_optimal_zero(
    example_job_shop_instance: JobShopInstance,
):
    """An action outside the optimal set must be mapped to 0.

    Fix: the previous version guarded everything behind ``if actions:`` /
    ``if actions_with_fake:``, so an empty action list made the test pass
    without executing a single assertion. We now assert the action list is
    non-empty up front, making a vacuous pass impossible.
    """
    solver = DispatchingRuleSolver()
    reference_schedule = solver.solve(example_job_shop_instance)
    dispatcher = Dispatcher(example_job_shop_instance)
    optimal_obs = OptimalOperationsObserver(dispatcher, reference_schedule)

    # Valid available actions as (operation_id, machine_id, job_id).
    actions = [
        (op.operation_id, op.machine_id, op.job_id)
        for op in dispatcher.available_operations()
    ]
    # A fresh dispatcher on a non-empty instance must offer operations.
    assert actions, "expected available operations at the initial state"

    # Forge an action with an invalid machine id: it cannot be optimal.
    fake_action = (actions[0][0], actions[0][1] + 99, actions[0][2])
    mapping = get_optimal_actions(optimal_obs, actions + [fake_action])

    # The forged action must be marked as non-optimal (0).
    assert mapping[fake_action] == 0
481+
482+
404483
if __name__ == "__main__":
    # Allow running this test module directly with verbose pytest output.
    pytest.main(["-vv", __file__])

0 commit comments

Comments
 (0)