Skip to content

Commit 4abac1b

Browse files
committed
Adjustments to env class name handling in ZooToGymMultiAgentAdapter
1 parent 14ec065 commit 4abac1b

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

aintelope/agents/sb3_base_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ def train(self, num_total_steps):
655655
) # need to resolve the conf before passing to subprocesses since OmegaConf resolvers do not seem to work well in subprocesses
656656

657657
env_wrapper = MultiAgentZooToGymAdapterZooSide(
658-
self.env, self.env_classname, self.cfg
658+
self.env, self.cfg, self.env_classname
659659
)
660660
self.models, self.exceptions = env_wrapper.train(
661661
num_total_steps=num_total_steps,

tests/environments/test_savanna_safetygrid_parallel.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,11 +248,14 @@ def sb3_gym_test_thread_entry_point(
248248
gpu_index,
249249
num_total_steps,
250250
model_constructor,
251+
env_classname,
251252
agent_id,
252253
checkpoint_filename,
253254
cfg,
254255
observation_space,
255256
action_space,
257+
*args,
258+
**kwargs,
256259
):
257260
env_wrapper = MultiAgentZooToGymAdapterGymSide(
258261
pipe, agent_id, checkpoint_filename, observation_space, action_space
@@ -291,7 +294,7 @@ def test_multiagent_zoo_to_gym_wrapper_scalarized_rewards(execution_number):
291294
env.seed(execution_number)
292295

293296
env_wrapper = MultiAgentZooToGymAdapterZooSide(
294-
env, cfg=None
297+
env, cfg=None, env_classname=None
295298
) # cfg is unused at sb3_gym_test_thread_entry_point() function
296299
_, exceptions = env_wrapper.train(
297300
num_total_steps=None, # unused at sb3_gym_test_thread_entry_point() function

tests/environments/test_savanna_safetygrid_sequential.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,11 +233,14 @@ def sb3_gym_test_thread_entry_point(
233233
gpu_index,
234234
num_total_steps,
235235
model_constructor,
236+
env_classname,
236237
agent_id,
237238
checkpoint_filename,
238239
cfg,
239240
observation_space,
240241
action_space,
242+
*args,
243+
**kwargs,
241244
):
242245
env_wrapper = MultiAgentZooToGymAdapterGymSide(
243246
pipe, agent_id, checkpoint_filename, observation_space, action_space
@@ -280,7 +283,7 @@ def sb3_gym_test_thread_entry_point(
280283
# env.seed(execution_number)
281284

282285
# env_wrapper = MultiAgentZooToGymAdapterZooSide(
283-
# env, cfg=None
286+
# env, cfg=None, env_classname=None
284287
# ) # cfg is unused at sb3_gym_test_thread_entry_point() function
285288
# _, exceptions = env_wrapper.train(
286289
# num_total_steps=None, # unused at sb3_gym_test_thread_entry_point() function

0 commit comments

Comments
 (0)