1010import numpy as np
1111import pytest
1212import torch
13+ from _utils_internal import get_available_devices
1314from torchrl .data .tensordict .memmap import MemmapTensor
1415
1516
@@ -35,7 +36,18 @@ def test_grad():
3536 MemmapTensor (t + 1 )
3637
3738
38- @pytest .mark .parametrize ("dtype" , [torch .float , torch .int , torch .double , torch .bool ])
39+ @pytest .mark .parametrize (
40+ "dtype" ,
41+ [
42+ torch .half ,
43+ torch .float ,
44+ torch .double ,
45+ torch .int ,
46+ torch .uint8 ,
47+ torch .long ,
48+ torch .bool ,
49+ ],
50+ )
3951@pytest .mark .parametrize (
4052 "shape" ,
4153 [
@@ -45,8 +57,9 @@ def test_grad():
4557 [1 , 2 ],
4658 ],
4759)
48- def test_memmap_metadata (dtype , shape ):
49- t = torch .tensor ([1 , 0 ]).reshape (shape )
60+ def test_memmap_data_type (dtype , shape ):
61+ """Test that MemmapTensor can be created with a given data type and shape."""
62+ t = torch .tensor ([1 , 0 ], dtype = dtype ).reshape (shape )
5063 m = MemmapTensor (t )
5164 assert m .dtype == t .dtype
5265 assert (m == t ).all ()
@@ -137,9 +150,49 @@ def test_memmap_clone():
137150 assert m2c == m1
138151
139152
140- def test_memmap_tensor ():
141- t = torch .tensor ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
142- assert (torch .tensor (t ) == t ).all ()
153+ @pytest .mark .parametrize ("device" , get_available_devices ())
154+ def test_memmap_same_device_as_tensor (device ):
155+ """
156+ Created MemmapTensor should be on the same device as the input tensor.
157+ Check if device is correct when .to(device) is called.
158+ """
159+ t = torch .tensor ([1 ], device = device )
160+ m = MemmapTensor (t )
161+ assert m .device == torch .device (device )
162+ for other_device in get_available_devices ():
163+ if other_device != device :
164+ with pytest .raises (
165+ RuntimeError ,
166+ match = "Expected all tensors to be on the same device, "
167+ + "but found at least two devices" ,
168+ ):
169+ assert torch .all (m + torch .ones ([3 , 4 ], device = other_device ) == 1 )
170+ m = m .to (other_device )
171+ assert m .device == torch .device (other_device )
172+
173+
174+ @pytest .mark .parametrize ("device" , get_available_devices ())
175+ def test_memmap_create_on_same_device (device ):
176+ """Test if the device arg for MemmapTensor init is respected."""
177+ m = MemmapTensor ([3 , 4 ], device = device )
178+ assert m .device == torch .device (device )
179+
180+
181+ @pytest .mark .parametrize ("device" , get_available_devices ())
182+ @pytest .mark .parametrize (
183+ "value" , [torch .zeros ([3 , 4 ]), MemmapTensor (torch .zeros ([3 , 4 ]))]
184+ )
185+ @pytest .mark .parametrize ("shape" , [[3 , 4 ], [[3 , 4 ]]])
186+ def test_memmap_zero_value (device , value , shape ):
187+ """
188+ Test if all entries are zeros when MemmapTensor is created with size.
189+ """
190+ value = value .to (device )
191+ expected_memmap_tensor = MemmapTensor (value )
192+ m = MemmapTensor (* shape , device = device )
193+ assert m .shape == (3 , 4 )
194+ assert torch .all (m == expected_memmap_tensor )
195+ assert torch .all (m + torch .ones ([3 , 4 ], device = device ) == 1 )
143196
144197
145198if __name__ == "__main__" :
0 commit comments