Skip to content

Commit 6d581f0

Browse files
committed
fix policy with device
1 parent 6783e0e commit 6d581f0

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

test/test_collector.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,8 +1828,14 @@ def forward(self, tensordict):
18281828
class PolicyWithDevice(TensorDictModuleBase):
18291829
in_keys = ["observation"]
18301830
out_keys = ["action"]
1831-
# receives and sends data on gpu
1832-
default_device = "cuda:0" if torch.cuda.device_count() else "cpu"
1831+
1832+
def __init__(self, default_device=None):
1833+
super().__init__()
1834+
self.default_device = (
1835+
default_device
1836+
if default_device is not None
1837+
else ("cuda:0" if torch.cuda.device_count() else "cpu")
1838+
)
18331839

18341840
def forward(self, tensordict):
18351841
assert tensordict.device == _make_ordinal_device(
@@ -1846,7 +1852,7 @@ def test_output_device(self, main_device, storing_device):
18461852
env_device = None
18471853
policy_device = main_device
18481854
env = self.DeviceLessEnv(main_device)
1849-
policy = self.PolicyWithDevice()
1855+
policy = self.PolicyWithDevice(main_device)
18501856
collector = SyncDataCollector(
18511857
env,
18521858
policy,
@@ -1887,7 +1893,7 @@ def test_output_device(self, main_device, storing_device):
18871893
env_device = None
18881894
policy_device = None
18891895
env = self.EnvWithDevice(main_device)
1890-
policy = self.PolicyWithDevice()
1896+
policy = self.PolicyWithDevice(main_device)
18911897
collector = SyncDataCollector(
18921898
env,
18931899
policy,
@@ -1909,7 +1915,7 @@ def test_output_device(self, main_device, storing_device):
19091915
env_device = main_device
19101916
policy_device = main_device
19111917
env = self.EnvWithDevice(main_device)
1912-
policy = self.PolicyWithDevice()
1918+
policy = self.PolicyWithDevice(main_device)
19131919
collector = SyncDataCollector(
19141920
env,
19151921
policy,

0 commit comments

Comments
 (0)