Skip to content

Commit b6fe45e

Browse files
committed
[BugFix] Fix unique ref to lambda func
ghstack-source-id: 567a2af Pull-Request: #3282
1 parent 7866d11 commit b6fe45e

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

torchrl/envs/batched_envs.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,19 @@ def _wrap_lambdas(create_env_fn):
156156
if callable(create_env_fn) and _is_unpicklable_lambda(create_env_fn):
157157
return EnvCreator(create_env_fn)
158158
if isinstance(create_env_fn, Sequence):
159-
return [
160-
EnvCreator(fn) if _is_unpicklable_lambda(fn) else fn
161-
for fn in create_env_fn
162-
]
159+
# Reuse EnvCreator for identical function objects to preserve
160+
# _single_task detection (e.g., when [lambda_fn] * 3 is passed)
161+
wrapped = {}
162+
result = []
163+
for fn in create_env_fn:
164+
if _is_unpicklable_lambda(fn):
165+
fn_id = id(fn)
166+
if fn_id not in wrapped:
167+
wrapped[fn_id] = EnvCreator(fn)
168+
result.append(wrapped[fn_id])
169+
else:
170+
result.append(fn)
171+
return result
163172
return create_env_fn
164173

165174
if "create_env_fn" in kwargs:

0 commit comments

Comments
 (0)