diff --git a/libero/libero/benchmark/__init__.py b/libero/libero/benchmark/__init__.py index 2b833d09..69b29ed4 100644 --- a/libero/libero/benchmark/__init__.py +++ b/libero/libero/benchmark/__init__.py @@ -161,7 +161,10 @@ def get_task_init_states(self, i): self.tasks[i].problem_folder, self.tasks[i].init_states_file, ) - init_states = torch.load(init_states_path) + if torch.__version__ >= "2.2.0": + init_states = torch.load(init_states_path, weights_only=False) + else: + init_states = torch.load(init_states_path) return init_states def set_task_embs(self, task_embs): diff --git a/libero/libero/envs/env_wrapper.py b/libero/libero/envs/env_wrapper.py index b5a732b0..99447e2a 100644 --- a/libero/libero/envs/env_wrapper.py +++ b/libero/libero/envs/env_wrapper.py @@ -38,13 +38,16 @@ def __init__( camera_segmentations=None, renderer="mujoco", renderer_config=None, + controller_configs=None, **kwargs, ): assert os.path.exists( bddl_file_name ), f"[error] {bddl_file_name} does not exist!" - controller_configs = suite.load_controller_config(default_controller=controller) + if controller_configs is None: + # Load default controller configs from robosuite + controller_configs = suite.load_controller_config(default_controller=controller) problem_info = BDDLUtils.get_problem_info(bddl_file_name) # Check if we're using a multi-armed environment and use env_configuration argument if so