@@ -1828,8 +1828,14 @@ def forward(self, tensordict):
18281828 class PolicyWithDevice (TensorDictModuleBase ):
18291829 in_keys = ["observation" ]
18301830 out_keys = ["action" ]
1831- # receives and sends data on gpu
1832- default_device = "cuda:0" if torch .cuda .device_count () else "cpu"
1831+
def __init__(self, default_device=None):
    """Initialise the module and store the device on the instance.

    Args:
        default_device: value to store as ``self.default_device``. When
            ``None``, ``"cuda:0"`` is used if ``torch.cuda.device_count()``
            is non-zero, otherwise ``"cpu"``.
    """
    super().__init__()
    if default_device is None:
        # No explicit device given: pick the first GPU when one is reported.
        default_device = "cuda:0" if torch.cuda.device_count() else "cpu"
    self.default_device = default_device
18331839
18341840 def forward (self , tensordict ):
18351841 assert tensordict .device == _make_ordinal_device (
@@ -1846,7 +1852,7 @@ def test_output_device(self, main_device, storing_device):
18461852 env_device = None
18471853 policy_device = main_device
18481854 env = self .DeviceLessEnv (main_device )
1849- policy = self .PolicyWithDevice ()
1855+ policy = self .PolicyWithDevice (main_device )
18501856 collector = SyncDataCollector (
18511857 env ,
18521858 policy ,
@@ -1887,7 +1893,7 @@ def test_output_device(self, main_device, storing_device):
18871893 env_device = None
18881894 policy_device = None
18891895 env = self .EnvWithDevice (main_device )
1890- policy = self .PolicyWithDevice ()
1896+ policy = self .PolicyWithDevice (main_device )
18911897 collector = SyncDataCollector (
18921898 env ,
18931899 policy ,
@@ -1909,7 +1915,7 @@ def test_output_device(self, main_device, storing_device):
19091915 env_device = main_device
19101916 policy_device = main_device
19111917 env = self .EnvWithDevice (main_device )
1912- policy = self .PolicyWithDevice ()
1918+ policy = self .PolicyWithDevice (main_device )
19131919 collector = SyncDataCollector (
19141920 env ,
19151921 policy ,
0 commit comments