 )
 from tensordict.base import NO_DEFAULT
 from tensordict.nn import CudaGraphModule, TensorDictModule
+from tensordict.utils import Buffer
 from torch import multiprocessing as mp
+from torch.nn import Parameter
 from torch.utils.data import IterableDataset
 
 from torchrl._utils import (
@@ -202,17 +204,17 @@ def map_weight(
     policy_device=policy_device,
 ):
 
-    is_param = isinstance(weight, nn.Parameter)
-    is_buffer = isinstance(weight, nn.Buffer)
+    is_param = isinstance(weight, Parameter)
+    is_buffer = isinstance(weight, Buffer)
     weight = weight.data
     if weight.device != policy_device:
         weight = weight.to(policy_device)
     elif weight.device.type in ("cpu", "mps"):
         weight = weight.share_memory_()
     if is_param:
-        weight = nn.Parameter(weight, requires_grad=False)
+        weight = Parameter(weight, requires_grad=False)
     elif is_buffer:
-        weight = nn.Buffer(weight)
+        weight = Buffer(weight)
     return weight
 
 # Create a stateless policy, then populate this copy with params on device
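The hunk above replaces `nn.Parameter` / `nn.Buffer` with `Parameter` imported from `torch.nn` and `Buffer` imported from `tensordict.utils`, presumably so the type checks also work on PyTorch builds where `torch.nn.Buffer` is not available. Below is a minimal, self-contained sketch of the same weight-mapping pattern; the name `map_weight_sketch` and the explicit `target_device` argument are illustrative stand-ins for the closed-over `policy_device`:

import torch
from tensordict.utils import Buffer  # stand-in for torch.nn.Buffer where needed
from torch.nn import Parameter


def map_weight_sketch(weight, target_device: torch.device):
    # Remember the wrapper type before operating on the raw tensor.
    is_param = isinstance(weight, Parameter)
    is_buffer = isinstance(weight, Buffer)
    weight = weight.data
    if weight.device != target_device:
        weight = weight.to(target_device)  # copy onto the policy device
    elif weight.device.type in ("cpu", "mps"):
        weight = weight.share_memory_()  # share across collector processes
    if is_param:
        # Re-wrap as a frozen Parameter so the copied module keeps its structure.
        weight = Parameter(weight, requires_grad=False)
    elif is_buffer:
        weight = Buffer(weight)  # re-wrap buffers the same way
    return weight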
@@ -3089,12 +3091,12 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR):
 
 
 def _make_meta_params(param):
-    is_param = isinstance(param, nn.Parameter)
+    is_param = isinstance(param, Parameter)
 
     pd = param.detach().to("meta")
 
     if is_param:
-        pd = nn.Parameter(pd, requires_grad=False)
+        pd = Parameter(pd, requires_grad=False)
     return pd
 
 
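`_make_meta_params` above keeps only the metadata of each parameter: detaching onto the `meta` device yields a tensor with the right shape and dtype but no storage, and re-wrapping it as a frozen `Parameter` preserves its type for the stateless policy copy. A minimal sketch of that pattern; the helper name and the example tensor are illustrative:

import torch
from torch.nn import Parameter


def make_meta_param_sketch(param):
    is_param = isinstance(param, Parameter)
    # "meta" tensors carry shape/dtype metadata but allocate no storage.
    pd = param.detach().to("meta")
    if is_param:
        pd = Parameter(pd, requires_grad=False)
    return pd


weight = Parameter(torch.randn(3, 4))
meta_weight = make_meta_param_sketch(weight)
print(meta_weight.shape, meta_weight.device)  # torch.Size([3, 4]) meta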