|
75 | 75 | @pytest.mark.parametrize("writer", [writers.RoundRobinWriter]) |
76 | 76 | @pytest.mark.parametrize("storage", [ListStorage, LazyTensorStorage, LazyMemmapStorage]) |
77 | 77 | @pytest.mark.parametrize("size", [3, 5, 100]) |
78 | | -class TestPrototypeBuffers: |
| 78 | +class TestComposableBuffers: |
79 | 79 | def _get_rb(self, rb_type, size, sampler, writer, storage): |
80 | 80 |
|
81 | 81 | if storage is not None: |
@@ -884,6 +884,67 @@ def test_samplerwithoutrep(size, samples, drop_last): |
884 | 884 | assert not visited |
885 | 885 |
|
886 | 886 |
|
class TestStateDict:
    """Round-trip tests for replay buffer ``state_dict`` / ``load_state_dict``."""

    @pytest.mark.parametrize("storage_in", ["tensor", "memmap"])
    @pytest.mark.parametrize("storage_out", ["tensor", "memmap"])
    @pytest.mark.parametrize("init_out", [True, False])
    def test_load_state_dict(self, storage_in, storage_out, init_out):
        """Save the state of one buffer and load it into another.

        A source buffer backed by the ``storage_in`` storage is extended three
        times with a single all-ones transition. Its state dict is then loaded
        into a second buffer backed by the ``storage_out`` storage, which has
        been extended once beforehand when ``init_out`` is true. A sample of
        three items drawn from the second buffer must contain only ones once
        the ``"index"`` entry is excluded.
        """
        buffer_size = 100
        # Map the parametrized storage name to the class that implements it.
        storage_classes = {
            "memmap": LazyMemmapStorage,
            "tensor": LazyTensorStorage,
        }
        source_storage = storage_classes[storage_in](buffer_size, device="cpu")
        target_storage = storage_classes[storage_out](buffer_size, device="cpu")

        source_buffer = TensorDictReplayBuffer(
            pin_memory=False,
            prefetch=3,
            storage=source_storage,
        )

        # A single transition whose entries are all ones; the final assertion
        # relies on every stored value being equal to 1.
        feature_widths = {"observation": 4, "action": 2, "reward": 1, "dones": 1}
        entries = {key: torch.ones(1, width) for key, width in feature_widths.items()}
        entries["next"] = {"observation": torch.ones(1, 4)}
        transition = TensorDict(entries, batch_size=1)

        for _ in range(3):
            source_buffer.extend(transition)

        saved_state = source_buffer.state_dict()

        target_buffer = TensorDictReplayBuffer(
            pin_memory=False,
            prefetch=3,
            storage=target_storage,
        )
        if init_out:
            # Covers loading into a buffer that has already been written to.
            target_buffer.extend(transition)

        target_buffer.load_state_dict(saved_state)
        sampled = target_buffer.sample(3)
        assert (sampled.exclude("index") == 1).all()
| 947 | + |
if __name__ == "__main__":
    # Forward any command-line arguments argparse does not recognise to pytest.
    _, passthrough_args = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *passthrough_args])
0 commit comments