Skip to content

Commit 3439abe

Browse files
[Environment] Melitngpot (#75)
1 parent 6296a1f commit 3439abe

File tree

79 files changed

+486
-58
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+486
-58
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2+
pip install dm-meltingpot

.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
strategy:
2121
fail-fast: false
2222
matrix:
23-
python-version: ["3.10"]
23+
python-version: ["3.11"]
2424

2525
steps:
2626
- uses: actions/checkout@v3
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# This workflow will install Python dependencies, run tests and lint with a single version of Python
2+
# For more information see:
3+
# https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
4+
5+
6+
name: meltingpot_tests
7+
8+
on:
9+
push:
10+
branches: [ $default-branch , "main" ]
11+
pull_request:
12+
branches: [ $default-branch , "main" ]
13+
14+
permissions:
15+
contents: read
16+
17+
jobs:
18+
tests:
19+
runs-on: ubuntu-latest
20+
strategy:
21+
fail-fast: false
22+
matrix:
23+
python-version: ["3.11"]
24+
25+
steps:
26+
- uses: actions/checkout@v3
27+
- name: Set up Python ${{ matrix.python-version }}
28+
uses: actions/setup-python@v3
29+
with:
30+
python-version: ${{ matrix.python-version }}
31+
- name: Install dependencies
32+
run: |
33+
bash .github/unittest/install_dependencies_nightly.sh
34+
- name: Install meltingpot
35+
run: |
36+
bash .github/unittest/install_meltingpot.sh
37+
- name: Test with pytest
38+
run: |
39+
pytest test/test_meltingpot.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html
40+
- name: Upload coverage to Codecov
41+
uses: codecov/codecov-action@v3
42+
with:
43+
fail_ci_if_error: false

.github/workflows/pettingzoo_tests.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
strategy:
2121
fail-fast: false
2222
matrix:
23-
python-version: ["3.8", "3.9", "3.10"]
23+
python-version: ["3.11"]
2424

2525
steps:
2626
- uses: actions/checkout@v3
@@ -37,9 +37,7 @@ jobs:
3737
- name: Test with pytest
3838
run: |
3939
xvfb-run -s "-screen 0 1024x768x24" pytest test/test_pettingzoo.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html
40-
41-
- if: matrix.python-version == '3.10'
42-
name: Upload coverage to Codecov
40+
- name: Upload coverage to Codecov
4341
uses: codecov/codecov-action@v3
4442
with:
4543
fail_ci_if_error: false

.github/workflows/smacv2_tests.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
strategy:
2222
fail-fast: false
2323
matrix:
24-
python-version: ["3.10"]
24+
python-version: ["3.11"]
2525

2626
steps:
2727
- uses: actions/checkout@v3
@@ -44,8 +44,7 @@ jobs:
4444
4545
pytest test/test_smacv2.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html
4646
47-
- if: matrix.python-version == '3.10'
48-
name: Upload coverage to Codecov
47+
- name: Upload coverage to Codecov
4948
uses: codecov/codecov-action@v3
5049
with:
5150
fail_ci_if_error: false

.github/workflows/torchrl_stable_tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
strategy:
2121
fail-fast: false
2222
matrix:
23-
python-version: ["3.10"]
23+
python-version: ["3.11"]
2424
steps:
2525
- uses: actions/checkout@v3
2626
- name: Set up Python ${{ matrix.python-version }}
@@ -38,4 +38,4 @@ jobs:
3838
bash .github/unittest/install_pettingzoo.sh
3939
- name: Tests
4040
run: |
41-
xvfb-run -s "-screen 0 1024x768x24" pytest test/test_algorithm.py test/test_models.py test/test_task.py test/test_vmas.py test/test_pettingzoo.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html
41+
xvfb-run -s "-screen 0 1024x768x24" pytest test/test_algorithm.py test/test_models.py test/test_task.py test/test_vmas.py test/test_pettingzoo.py test/test_meltingpot.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html

.github/workflows/unit_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
strategy:
2121
fail-fast: false
2222
matrix:
23-
python-version: ["3.8", "3.9", "3.10"]
23+
python-version: ["3.8", "3.9", "3.10","3.11"]
2424

2525
steps:
2626
- uses: actions/checkout@v3

.github/workflows/vmas_tests.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
strategy:
2121
fail-fast: false
2222
matrix:
23-
python-version: ["3.8", "3.9", "3.10"]
23+
python-version: ["3.11"]
2424

2525
steps:
2626
- uses: actions/checkout@v3
@@ -38,8 +38,7 @@ jobs:
3838
run: |
3939
xvfb-run -s "-screen 0 1024x768x24" pytest test/test_vmas.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html
4040
41-
- if: matrix.python-version == '3.10'
42-
name: Upload coverage to Codecov
41+
- name: Upload coverage to Codecov
4342
uses: codecov/codecov-action@v3
4443
with:
4544
fail_ci_if_error: false

README.md

Lines changed: 13 additions & 7 deletions

benchmarl/algorithms/common.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pathlib
88
from abc import ABC, abstractmethod
99
from dataclasses import dataclass
10-
from typing import Any, Dict, Iterable, Optional, Tuple, Type
10+
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
1111

1212
from tensordict import TensorDictBase
1313
from tensordict.nn import TensorDictModule, TensorDictSequential
@@ -19,6 +19,7 @@
1919
TensorDictReplayBuffer,
2020
)
2121
from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement
22+
from torchrl.envs import Compose, Transform
2223
from torchrl.objectives import LossModule
2324
from torchrl.objectives.utils import HardUpdate, SoftUpdate, TargetNetUpdater
2425

@@ -132,15 +133,15 @@ def get_loss_and_updater(self, group: str) -> Tuple[LossModule, TargetNetUpdater
132133
return self._losses_and_updaters[group]
133134

134135
def get_replay_buffer(
135-
self,
136-
group: str,
136+
self, group: str, transforms: List[Transform] = None
137137
) -> ReplayBuffer:
138138
"""
139139
Get the ReplayBuffer for a specific group.
140140
This function will check ``self.on_policy`` and create the buffer accordingly
141141
142142
Args:
143143
group (str): agent group of the loss and updater
144+
transforms (optional, list of Transform): Transforms to apply to the replay buffer ``.sample()`` call
144145
145146
Returns: ReplayBuffer the group
146147
"""
@@ -154,6 +155,7 @@ def get_replay_buffer(
154155
sampler=sampler,
155156
batch_size=sampling_size,
156157
priority_key=(group, "td_error"),
158+
transform=Compose(*transforms) if transforms is not None else None,
157159
)
158160

159161
def get_policy_for_loss(self, group: str) -> TensorDictModule:

0 commit comments

Comments
 (0)