Skip to content

Commit b4a2cea

Browse files
authored
Fix shared CNN modules in _OnnxCNNModel and _TorchCNNModel (#189)
Deep-copy CNN modules in export wrappers to prevent shared state with the original training model. Fixes #188.
1 parent 0d83b2f commit b4a2cea

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

rsl_rl/models/cnn_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def __init__(self, model: CNNModel) -> None:
168168
super().__init__()
169169
self.obs_normalizer = copy.deepcopy(model.obs_normalizer)
170170
# Convert ModuleDict to ModuleList for ordered iteration
171-
self.cnns = nn.ModuleList([model.cnns[g] for g in model.obs_groups_2d])
171+
self.cnns = nn.ModuleList([copy.deepcopy(model.cnns[g]) for g in model.obs_groups_2d])
172172
self.mlp = copy.deepcopy(model.mlp)
173173
if model.distribution is not None:
174174
self.deterministic_output = model.distribution.as_deterministic_output_module()
@@ -204,7 +204,7 @@ def __init__(self, model: CNNModel, verbose: bool) -> None:
204204
self.verbose = verbose
205205
self.obs_normalizer = copy.deepcopy(model.obs_normalizer)
206206
# Convert ModuleDict to ModuleList for ordered iteration
207-
self.cnns = nn.ModuleList([model.cnns[g] for g in model.obs_groups_2d])
207+
self.cnns = nn.ModuleList([copy.deepcopy(model.cnns[g]) for g in model.obs_groups_2d])
208208
self.mlp = copy.deepcopy(model.mlp)
209209
if model.distribution is not None:
210210
self.deterministic_output = model.distribution.as_deterministic_output_module()

0 commit comments

Comments
 (0)