Commit 45cdb31

Merge pull request #164 from Dooders/dev
Add Q-learning state initialization and new README documentation
2 parents 2c54f71 + 8843fb4 commit 45cdb31

File tree

7 files changed, +366 -0 lines changed

agents/README.md

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
# Agents Module

This module provides a variety of agent classes for use in reinforcement learning and maze navigation environments. Agents can be used as-is or extended for custom behaviors. Many agents have both standard and memory-augmented variants that leverage episodic and semantic memory for improved performance.

## Agent Types

### 1. `Agent` (Abstract Base Class)
Defines the interface for all agents. To implement a custom agent, inherit from this class and implement the required methods.

**API:**
```python
class Agent(ABC):
    def __init__(self, agent_id: str, action_space, **kwargs): ...
    @abstractmethod
    def act(self, observation: MazeObservation, epsilon: float = 0.1) -> int: ...
    @abstractmethod
    def set_demo_path(self, path: list[int]) -> None: ...
```

### 2. `RandomAgent`
Selects actions randomly from the action space. Useful as a baseline.

### 3. `MemoryRandomAgent`
A random agent that also stores and retrieves state/action information from a memory system, biasing action selection toward previously successful actions.

### 4. `AlgoAgent`
A planning agent that uses search algorithms (BFS/DFS or custom) to plan a path to the target. Good for deterministic environments.
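For orientation, an agent of this kind plans with a graph search over the maze grid and then executes the first step of the resulting path. The sketch below is a generic BFS planner on a 4-connected grid, not the repository's `AlgoAgent` implementation; the function name and parameters are illustrative.

```python
# Generic BFS path planning on a 4-connected grid (illustrative sketch only).
from collections import deque

def plan_path(start, target, obstacles, width, height):
    """Return a list of (row, col) cells from start to target, or [] if unreachable."""
    moves = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # up, down, left, right
    queue = deque([start])
    came_from = {start: None}
    while queue:
        cell = queue.popleft()
        if cell == target:
            # Reconstruct the path by walking parent links back to the start.
            path = []
            while cell is not None:
                path.append(cell)
                cell = came_from[cell]
            return path[::-1]
        for dr, dc in moves:
            nxt = (cell[0] + dr, cell[1] + dc)
            if (0 <= nxt[0] < height and 0 <= nxt[1] < width
                    and nxt not in obstacles and nxt not in came_from):
                came_from[nxt] = cell
                queue.append(nxt)
    return []  # no path found
```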

### 5. `MemoryAlgoAgent`
A planning agent with memory augmentation. Retrieves similar states from memory to bias planning and action selection.

### 6. `QAgent`
Implements tabular Q-learning. Maintains a Q-table for state-action values and uses an epsilon-greedy policy.
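For context, the standard tabular Q-learning update applied after a transition $(s, a, r, s')$ is, with learning rate $\alpha$ and discount factor $\gamma$ (the exact hyperparameter names used by `QAgent` are not shown in this commit):

$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$$

On terminal transitions the bootstrap term $\gamma \max_{a'} Q(s', a')$ is dropped and the target is just $r$, which is presumably what the `if done:` branch visible in the `q_agent.py` hunk below handles.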

### 7. `MemoryQAgent`
A Q-learning agent with memory augmentation. Stores and retrieves states, actions, and interactions from memory to bias exploration and exploitation.

### 8. `DeepQAgent`
Implements Deep Q-Learning using PyTorch. Uses a neural network to approximate Q-values and experience replay for training.
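As background, deep Q-learning swaps the Q-table for a neural network trained on minibatches drawn from a replay buffer. The snippet below is a minimal, generic sketch of those ingredients; the actual `DeepQAgent` network architecture, constructor arguments, and replay logic are not shown in this commit.

```python
# Minimal, generic Deep Q-Learning ingredients (sketch only, not DeepQAgent itself).
import random
from collections import deque

import torch
import torch.nn as nn

class QNetwork(nn.Module):
    """Small MLP that maps an observation vector to one Q-value per action."""
    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

def epsilon_greedy(q_net: QNetwork, obs: torch.Tensor, n_actions: int, epsilon: float) -> int:
    if random.random() < epsilon:
        return random.randrange(n_actions)       # explore
    with torch.no_grad():
        return int(q_net(obs).argmax().item())   # exploit

# Experience replay: store transitions, sample random minibatches for training.
replay_buffer: deque = deque(maxlen=10_000)      # (obs, action, reward, next_obs, done)
```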

### 9. `MemoryDeepQAgent`
A deep Q-learning agent with memory augmentation. Stores and retrieves states and interactions from memory to bias action selection and learning.

---

## Usage

> **Note:** Only the abstract `Agent` is exposed in `agents/__init__.py`. To use concrete agents, import them directly from their respective files:

```python
from agents.random_agent import RandomAgent, MemoryRandomAgent
from agents.algo_agent import AlgoAgent, MemoryAlgoAgent
from agents.q_agent import QAgent, MemoryQAgent
from agents.deep_q_agent import DeepQAgent, MemoryDeepQAgent
```

## Example

```python
from agents.q_agent import QAgent
from memory.api.models import MazeObservation

agent = QAgent(agent_id="A1", action_space=4)
obs = MazeObservation(position=(0, 0), target=(3, 3), steps=0, nearby_obstacles=[])
action = agent.act(obs)
```

## Extending Agents
To create your own agent, inherit from `Agent` and implement the `act` and `set_demo_path` methods.
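For illustration, a minimal custom agent might look like the sketch below. `GreedyAgent` is a hypothetical name; the action encoding and the `MazeObservation` attribute access are assumptions based on the example above.

```python
from agents import Agent                       # only the abstract base is exported here
from memory.api.models import MazeObservation

class GreedyAgent(Agent):
    """Hypothetical agent that walks straight toward the target (sketch only)."""

    def __init__(self, agent_id: str, action_space, **kwargs):
        super().__init__(agent_id, action_space, **kwargs)
        self.demo_path: list[int] = []

    def act(self, observation: MazeObservation, epsilon: float = 0.1) -> int:
        if self.demo_path:                     # replay a scripted path first, if one was set
            return self.demo_path.pop(0)
        row, col = observation.position
        t_row, t_col = observation.target
        if t_row > row:
            return 1                           # action indices here are illustrative
        if t_row < row:
            return 0
        return 3 if t_col > col else 2

    def set_demo_path(self, path: list[int]) -> None:
        self.demo_path = list(path)
```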

## Memory-Augmented Agents
Memory-augmented agents use a `MemorySpace` object to store and retrieve states, actions, and interactions. This enables the following (a brief sketch follows the list):
- Retrieval of similar past states for biasing action selection
- Storing successful actions/interactions for future use
- Episodic and semantic memory integration
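
The first of these points, biasing action selection with recalled states, is sketched below. The `retrieve_similar_states` call and the memory payload shape follow the tests added in this commit; the helper function itself and the 0.2 exploration threshold (suggested by a test comment) are assumptions.

```python
import numpy as np

def memory_biased_action(memory, observation, action_space: int) -> int:
    """Hypothetical helper: prefer an action recalled from similar past states."""
    similar = memory.retrieve_similar_states(observation)
    if similar and np.random.random() > 0.2:
        # Exploit: reuse the action stored with the most similar remembered state.
        return similar[0]["content"]["action"]
    # Explore: otherwise fall back to a random action.
    return int(np.random.randint(action_space))
```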

## Requirements
- `memory` module (for memory-augmented agents)
- `numpy`, `torch` (for DeepQAgent)

---

## File Overview
- `base.py`: Abstract base class
- `random_agent.py`: RandomAgent, MemoryRandomAgent
- `algo_agent.py`: AlgoAgent, MemoryAlgoAgent
- `q_agent.py`: QAgent, MemoryQAgent
- `deep_q_agent.py`: DeepQAgent, MemoryDeepQAgent

---

For more details, see the docstrings in each agent class.

agents/q_agent.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -118,6 +118,8 @@ def update_q_value(
             self.q_table[next_state_key] = np.zeros(self.action_space)
 
         # Q-learning update
+        if state_key not in self.q_table:
+            self.q_table[state_key] = np.zeros(self.action_space)
         current_q = self.q_table[state_key][action]
 
         if done:
```
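The added guard mirrors the existing handling of `next_state_key` a few lines above: if `update_q_value` is called for a state the agent has never visited (for example, the very first transition seen by a fresh agent), `state_key` is missing from `q_table` and reading `self.q_table[state_key][action]` would raise a `KeyError`. Lazily initializing the row to zeros makes the update safe for unseen states, which is what `test_q_agent_q_value_update` below exercises on a freshly constructed `QAgent`.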

tests/agents/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
(a single blank line, making `tests/agents` an importable package)

tests/agents/test_algo_agent.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
```python
import pytest
import numpy as np
from unittest.mock import MagicMock

from agents.algo_agent import AlgoAgent, MemoryAlgoAgent
from memory.api.models import MazeObservation

@pytest.fixture
def sample_observation():
    return MazeObservation(
        position=(1, 1),
        target=(2, 2),
        nearby_obstacles=[(0, 1), (1, 0)],
        steps=5,
    )


def test_algo_agent_bfs_path(sample_observation):
    agent = AlgoAgent(agent_id="test", action_space=4, search_algo="bfs")
    action = agent.act(sample_observation)
    assert 0 <= action < 4


def test_algo_agent_dfs_path(sample_observation):
    agent = AlgoAgent(agent_id="test", action_space=4, search_algo="dfs")
    action = agent.act(sample_observation)
    assert 0 <= action < 4


def test_algo_agent_demo_path(sample_observation):
    agent = AlgoAgent(agent_id="test", action_space=4)
    agent.set_demo_path([1, 2])
    assert agent.act(sample_observation) == 1
    assert agent.act(sample_observation) == 2
    # After demo path, should revert to planning
    action = agent.act(sample_observation)
    assert 0 <= action < 4


def test_memory_algo_agent_act_returns_valid_action(sample_observation):
    agent = MemoryAlgoAgent(agent_id="test", action_space=4)
    agent.memory = MagicMock()
    agent.memory.retrieve_similar_states.return_value = []
    action = agent.act(sample_observation)
    assert 0 <= action < 4


def test_memory_algo_agent_demo_path(sample_observation):
    agent = MemoryAlgoAgent(agent_id="test", action_space=4)
    agent.set_demo_path([3, 0])
    assert agent.act(sample_observation) == 3
    assert agent.act(sample_observation) == 0


def test_memory_algo_agent_memory_action(sample_observation):
    agent = MemoryAlgoAgent(agent_id="test", action_space=4)
    agent.memory = MagicMock()
    agent.memory.retrieve_similar_states.return_value = [
        {"content": {"action": 1, "reward": 1}}
    ]
    np_random_backup = np.random.random
    np.random.random = lambda: 0.5
    action = agent.act(sample_observation)
    np.random.random = np_random_backup
    assert action == 1
```

tests/agents/test_deep_q_agent.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
```python
import pytest
import numpy as np
from unittest.mock import MagicMock
import torch

from agents.deep_q_agent import DeepQAgent, MemoryDeepQAgent
from memory.api.models import MazeObservation

@pytest.fixture
def sample_observation():
    return MazeObservation(
        position=(0, 0),
        target=(1, 1),
        nearby_obstacles=[(0, 1)],
        steps=1,
    )

@pytest.fixture
def next_observation():
    return MazeObservation(
        position=(0, 1),
        target=(1, 1),
        nearby_obstacles=[(1, 1)],
        steps=2,
    )

def test_deep_q_agent_epsilon_greedy_action(sample_observation):
    agent = DeepQAgent(agent_id="test", action_space=4)
    np_random_backup = np.random.random
    np.random.random = lambda: 0.05
    action = agent.act(sample_observation, epsilon=1.0)
    np.random.random = np_random_backup
    assert 0 <= action < 4

def test_deep_q_agent_demo_path(sample_observation):
    agent = DeepQAgent(agent_id="test", action_space=4)
    agent.set_demo_path([2, 1])
    assert agent.act(sample_observation) == 2
    assert agent.act(sample_observation) == 1

def test_deep_q_agent_experience_replay(sample_observation, next_observation):
    agent = DeepQAgent(agent_id="test", action_space=4, batch_size=1)
    agent.remember(sample_observation, 1, 1.0, next_observation, False)
    # Should not raise error
    agent.update()

def test_memory_deep_q_agent_act_returns_valid_action(sample_observation):
    agent = MemoryDeepQAgent(agent_id="test", action_space=4)
    agent.memory = MagicMock()
    agent.memory.retrieve_similar_states.return_value = []
    action = agent.act(sample_observation)
    assert 0 <= action < 4

def test_memory_deep_q_agent_demo_path(sample_observation):
    agent = MemoryDeepQAgent(agent_id="test", action_space=4)
    agent.set_demo_path([3, 0])
    assert agent.act(sample_observation) == 3
    assert agent.act(sample_observation) == 0

def test_memory_deep_q_agent_memory_action(sample_observation):
    agent = MemoryDeepQAgent(agent_id="test", action_space=4)
    agent.memory = MagicMock()
    agent.memory.retrieve_similar_states.return_value = [
        {"content": {"action": 2, "reward": 1}}
    ]
    np_random_backup = np.random.random
    np.random.random = lambda: 0.5
    action = agent.act(sample_observation)
    np.random.random = np_random_backup
    assert action == 2
```

tests/agents/test_q_agent.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
```python
import pytest
import numpy as np
from unittest.mock import MagicMock

from agents.q_agent import QAgent, MemoryQAgent
from memory.api.models import MazeObservation

@pytest.fixture
def sample_observation():
    return MazeObservation(
        position=(0, 0),
        target=(1, 1),
        nearby_obstacles=[(0, 1)],
        steps=1,
    )

@pytest.fixture
def next_observation():
    return MazeObservation(
        position=(0, 1),
        target=(1, 1),
        nearby_obstacles=[(1, 1)],
        steps=2,
    )

def test_q_agent_epsilon_greedy_action(sample_observation):
    agent = QAgent(agent_id="test", action_space=4)
    # Force random action
    np_random_backup = np.random.random
    np.random.random = lambda: 0.05
    action = agent.act(sample_observation, epsilon=1.0)
    np.random.random = np_random_backup
    assert 0 <= action < 4

def test_q_agent_demo_path(sample_observation):
    agent = QAgent(agent_id="test", action_space=4)
    agent.set_demo_path([2, 1])
    assert agent.act(sample_observation) == 2
    assert agent.act(sample_observation) == 1

def test_q_agent_q_value_update(sample_observation, next_observation):
    agent = QAgent(agent_id="test", action_space=4)
    action = 1
    reward = 1.0
    done = False
    agent.update_q_value(sample_observation, action, reward, next_observation, done)
    state_key = agent._get_state_key(sample_observation)
    assert agent.q_table[state_key][action] != 0

def test_memory_q_agent_act_returns_valid_action(sample_observation):
    agent = MemoryQAgent(agent_id="test", action_space=4)
    agent.memory = MagicMock()
    agent.memory.retrieve_similar_states.return_value = []
    action = agent.act(sample_observation)
    assert 0 <= action < 4

def test_memory_q_agent_demo_path(sample_observation):
    agent = MemoryQAgent(agent_id="test", action_space=4)
    agent.set_demo_path([3, 0])
    assert agent.act(sample_observation) == 3
    assert agent.act(sample_observation) == 0

def test_memory_q_agent_memory_action(sample_observation):
    agent = MemoryQAgent(agent_id="test", action_space=4)
    agent.memory = MagicMock()
    agent.memory.retrieve_similar_states.return_value = [
        {"content": {"action": 2, "reward": 1}}
    ]
    np_random_backup = np.random.random
    np.random.random = lambda: 0.5
    action = agent.act(sample_observation)
    np.random.random = np_random_backup
    assert action == 2
```

tests/agents/test_random_agent.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
```python
import pytest
import numpy as np
from unittest.mock import MagicMock

from agents.random_agent import RandomAgent, MemoryRandomAgent
from memory.api.models import MazeObservation, MazeActionSpace

@pytest.fixture
def sample_observation():
    return MazeObservation(
        position=(1, 1),
        target=(2, 2),
        nearby_obstacles=[(0, 1), (1, 0)],
        steps=5,
    )


def test_random_agent_act_returns_valid_action(sample_observation):
    agent = RandomAgent(agent_id="test", action_space=4)
    action = agent.act(sample_observation)
    assert 0 <= action < 4


def test_random_agent_demo_path(sample_observation):
    agent = RandomAgent(agent_id="test", action_space=4)
    agent.set_demo_path([2, 3, 1])
    assert agent.act(sample_observation) == 2
    assert agent.act(sample_observation) == 3
    assert agent.act(sample_observation) == 1
    # After demo path, should revert to random
    action = agent.act(sample_observation)
    assert 0 <= action < 4


def test_memory_random_agent_act_returns_valid_action(sample_observation, monkeypatch):
    agent = MemoryRandomAgent(agent_id="test", action_space=4)
    # Patch memory.retrieve_similar_states to return empty
    agent.memory = MagicMock()
    agent.memory.retrieve_similar_states.return_value = []
    action = agent.act(sample_observation)
    assert 0 <= action < 4


def test_memory_random_agent_demo_path(sample_observation):
    agent = MemoryRandomAgent(agent_id="test", action_space=4)
    agent.set_demo_path([1, 0])
    assert agent.act(sample_observation) == 1
    assert agent.act(sample_observation) == 0


def test_memory_random_agent_memory_action(sample_observation):
    agent = MemoryRandomAgent(agent_id="test", action_space=4)
    # Patch memory.retrieve_similar_states to return a memory with action 2
    agent.memory = MagicMock()
    agent.memory.retrieve_similar_states.return_value = [
        {"content": {"action": 2, "reward": 1}}
    ]
    # Patch np.random.random to always return 0.5 (> 0.2)
    np_random_backup = np.random.random
    np.random.random = lambda: 0.5
    action = agent.act(sample_observation)
    np.random.random = np_random_backup
    assert action == 2
```
