Commit 92abcb5 (1 parent: c33b71b)

Enhance project structure and documentation

- Move language selection to the top of the README for better visibility
- Add comprehensive API reference and contributing guidelines
- Create GitHub Actions CI workflow for automated testing
- Add CONTRIBUTING.md with development guidelines
- Include example.py for easy getting started
- Update .gitignore for better development file management
- Set up development dependencies and code quality tools

File tree: 5 files changed (+253, −3 lines)

.github/workflows/ci.yml (new file: 63 additions, 0 deletions)

```yaml
name: CI

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.8, 3.9, "3.10", "3.11"]

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e .[dev]

      - name: Lint with flake8
        run: |
          # Stop the build if there are Python syntax errors or undefined names
          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
          # Exit-zero treats all errors as warnings
          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=88 --statistics

      - name: Check code formatting with black
        run: |
          black --check .

      - name: Type checking with mypy
        run: |
          mypy . --ignore-missing-imports

      - name: Test installation
        run: |
          python -c "from sb3_grpo import GRPO; print('Import successful!')"

      - name: Run basic functionality test
        run: |
          python -c "
          import gymnasium as gym
          import torch
          from stable_baselines3.common.vec_env import DummyVecEnv
          from sb3_grpo import GRPO

          def simple_reward(state, action, next_state):
              return torch.ones(state.shape[0], 1)

          env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
          agent = GRPO('MlpPolicy', env, reward_function=simple_reward, verbose=0)
          agent.learn(total_timesteps=100)
          print('Basic functionality test passed!')
          "
```

.gitignore (5 additions, 0 deletions)

```diff
@@ -3,6 +3,11 @@ grpo_アクションサンプル用メモリサイズ.md
 grpoの流れ.md
 README_ja.md
 
+# Example outputs and temporary files
+example_outputs/
+temp/
+*.tmp
+
 # Python
 __pycache__/
 *.py[cod]
```

CONTRIBUTING.md (new file: 68 additions, 0 deletions)

````markdown
# Contributing to SB3-GRPO

Thank you for your interest in contributing to SB3-GRPO!

## Development Setup

1. Fork the repository
2. Clone your fork:
   ```bash
   git clone https://github.com/yourusername/sb3-grpo.git
   cd sb3-grpo
   ```
3. Install in development mode:
   ```bash
   pip install -e .[dev]
   ```

## Code Style

We use the following tools to maintain code quality:

- **Black**: Code formatting
- **Flake8**: Linting
- **MyPy**: Type checking

Run these before submitting:

```bash
black .
flake8 .
mypy . --ignore-missing-imports
```

## Testing

Make sure your changes don't break existing functionality:

```bash
python -c "from sb3_grpo import GRPO; print('Import test passed!')"
```

## Pull Request Process

1. Create a feature branch from `main`
2. Make your changes
3. Run the code quality tools
4. Test your changes
5. Submit a pull request with a clear description

## Reporting Issues

When reporting bugs, please include:

- Python version
- PyTorch version
- Stable Baselines3 version
- Complete error traceback
- Minimal reproduction example

## Feature Requests

We welcome feature requests! Please open an issue with:

- Clear description of the feature
- Use case examples
- Proposed API (if applicable)

Thank you for contributing!
````
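The Black/Flake8/MyPy checks described in CONTRIBUTING.md could also be run automatically on each commit via pre-commit. The following configuration is purely illustrative and not part of this commit; the pinned `rev` tags are assumptions and should be updated to current releases:

```yaml
# .pre-commit-config.yaml (hypothetical; mirrors the tools listed in CONTRIBUTING.md)
repos:
  - repo: https://github.com/psf/black
    rev: 24.3.0          # assumed tag; pin to a current release
    hooks:
      - id: black
  - repo: https://github.com/PyCQA/flake8
    rev: 7.0.0           # assumed tag
    hooks:
      - id: flake8
        args: ["--max-line-length=88"]   # matches the CI flake8 settings
  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.9.0          # assumed tag
    hooks:
      - id: mypy
        args: ["--ignore-missing-imports"]
```

After `pip install pre-commit && pre-commit install`, the hooks run on staged files before each commit, catching formatting and lint issues earlier than CI.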

README.md (48 additions, 3 deletions)

````diff
@@ -5,6 +5,13 @@
 
 [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
 
+## Language Versions / 言語選択
+
+- **English**: [README.md](README.md) (this file)
+- **日本語**: [README_ja.md](README_ja.md)
+
+---
+
 `sb3-grpo` is a [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3) (SB3) compatible implementation of **Group Relative Policy Optimization (GRPO)**.
 
 This algorithm can be used as a drop-in replacement for standard PPO, providing stable learning especially in environments where rewards can be densely defined for states and actions.
@@ -159,10 +166,48 @@ python example.py
 
 As training progresses, standard SB3 logs will be displayed. If the agent can maintain CartPole upright for extended periods after training, it's successful.
 
-## Language Versions
+## API Reference
 
-- **English**: [README.md](README.md) (this file)
-- **日本語**: [README_ja.md](README_ja.md)
+### GRPO Class
+
+```python
+class GRPO(PPO):
+    """
+    Group Relative Policy Optimization (GRPO) implementation extending PPO.
+
+    Args:
+        policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
+        env: The environment to learn from
+        reward_function: Function to calculate rewards from (state, action, next_state)
+        **kwargs: Other standard PPO arguments (learning_rate, n_steps, etc.)
+    """
+```
+
+### Reward Function Interface
+
+Your reward function must follow this signature:
+
+```python
+def your_reward_function(
+    state: torch.Tensor,       # Current state [batch_size, state_dim]
+    action: torch.Tensor,      # Action taken [batch_size, 1]
+    next_state: torch.Tensor,  # Resulting state [batch_size, state_dim]
+) -> torch.Tensor:             # Returns: rewards [batch_size, 1]
+    # Your reward calculation logic here
+    return rewards
+```
+
+## Contributing
+
+Contributions are welcome! Please feel free to submit a Pull Request. For major changes, please open an issue first to discuss what you would like to change.
+
+### Development Setup
+
+```bash
+git clone https://github.com/kechirojp/sb3-grpo.git
+cd sb3-grpo
+pip install -e .[dev]  # Install with development dependencies
+```
 
 ## License
````
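As a quick sanity check of the reward-function interface documented in the README, the sketch below builds a batch of fake CartPole observations and verifies the `[batch_size, 1]` output contract. The function name and reward constants are illustrative (loosely mirroring the commit's example.py), not part of the library API:

```python
import torch


def angle_position_reward(state: torch.Tensor,
                          action: torch.Tensor,
                          next_state: torch.Tensor) -> torch.Tensor:
    """Toy CartPole-style reward: prefer a small pole angle and a centered cart."""
    # next_state layout (CartPole): [cart_pos, cart_vel, pole_angle, pole_vel]
    cart_pos = next_state[:, 0]
    pole_angle = next_state[:, 2]
    reward = 1.0 - torch.abs(pole_angle) - 0.1 * torch.abs(cart_pos)
    return reward.unsqueeze(-1)  # shape [batch_size, 1], per the interface


# Fake batch of 8 CartPole observations (state_dim = 4), all zeros
batch = torch.zeros(8, 4)
r = angle_position_reward(batch, torch.zeros(8, 1), batch)
assert r.shape == (8, 1)  # matches the documented return shape
```

Keeping the trailing `unsqueeze(-1)` is the easy part to forget: indexing with `[:, 2]` drops a dimension, and GRPO expects a `[batch_size, 1]` reward tensor, not `[batch_size]`.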

example.py (new file: 69 additions, 0 deletions)

```python
# example.py

import gymnasium as gym
import torch
from stable_baselines3.common.vec_env import DummyVecEnv

# Import GRPO from the package
from sb3_grpo import GRPO


# --- 1. Define reward function for GRPO ---
# The core of GRPO is the ability to inject custom reward functions.
# Here we define a function that evaluates how "good" the next state is.
def cartpole_reward_fn(state: torch.Tensor, action: torch.Tensor, next_state: torch.Tensor) -> torch.Tensor:
    """
    Reward function for the CartPole environment.
    Evaluates how "good" the next_state is:
    - Higher reward for a pole angle closer to vertical
    - Higher reward for a cart position closer to center
    """
    # next_state contents: [cart_pos, cart_vel, pole_angle, pole_vel]
    cart_pos = next_state[:, 0]
    pole_angle = next_state[:, 2]

    # Reward is higher when angle and position are closer to 0
    reward = 1.0 - torch.abs(pole_angle) - 0.1 * torch.abs(cart_pos)

    return reward.unsqueeze(-1)


# --- 2. Environment setup ---
# Standard Stable Baselines3 environment preparation
env = gym.make("CartPole-v1")
env = DummyVecEnv([lambda: env])


# --- 3. Create GRPO agent ---
# Usage is almost identical to PPO instantiation.
agent = GRPO(
    "MlpPolicy",
    env,
    reward_function=cartpole_reward_fn,  # Inject reward function here
    n_steps=256,
    batch_size=64,
    n_epochs=10,
    learning_rate=3e-4,
    verbose=1,
)

# --- 4. Training ---
# Just call the `learn` method like standard SB3 PPO
print("--- Starting GRPO Training ---")
agent.learn(total_timesteps=20000)
print("--- Training Finished ---")


# --- 5. Evaluate trained agent ---
print("\n--- Evaluating Trained Agent ---")
eval_env = gym.make("CartPole-v1")
obs, _ = eval_env.reset()
total_reward = 0
for _ in range(1000):
    action, _ = agent.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = eval_env.step(action)
    total_reward += reward
    if terminated or truncated:
        print(f"Episode finished with total reward: {total_reward}")
        total_reward = 0
        obs, _ = eval_env.reset()
eval_env.close()
```