Changes from all commits
libero/libero/__init__.py (12 additions, 34 deletions)

```diff
@@ -1,11 +1,14 @@
 import os
 import warnings
 import yaml
+from filelock import FileLock
 
 # This is a default path for localizing all the benchmark related files
 libero_config_path = os.environ.get(
     "LIBERO_CONFIG_PATH", os.path.expanduser("~/.libero")
 )
 config_file = os.path.join(libero_config_path, "config.yaml")
+config_file_lock = os.path.join(libero_config_path, ".config_lock")
+
 
 def get_default_path_dict(custom_location=None):
@@ -39,11 +42,6 @@ def get_libero_path(query_key):
     with open(config_file, "r") as f:
         config = dict(yaml.load(f.read(), Loader=yaml.FullLoader))
 
-    # Give warnings in case the user needs to access the paths
-    for key in config:
-        if not os.path.exists(config[key]):
-            print(f"[Warning]: {key} path {config[key]} does not exist!")
-
     assert (
         query_key in config
     ), f"Key {query_key} not found in config file {config_file}. You need to modify it. Available keys are: {config.keys()}"
@@ -52,45 +50,25 @@ def get_libero_path(query_key):
 
 def set_libero_default_path(custom_location=os.path.dirname(os.path.abspath(__file__))):
     print(
-        f"[Warning] You are changing the default path for Libero config. This will affect all the paths in the config file."
+        "[Warning] You are changing the default path for Libero config. This will affect all the paths in the config file."
     )
     new_config = get_default_path_dict(custom_location)
     with open(config_file, "w") as f:
         yaml.dump(new_config, f)
 
 
-if not os.path.exists(libero_config_path):
-    os.makedirs(libero_config_path)
+with FileLock(config_file_lock):
+    if not os.path.exists(libero_config_path):
+        os.makedirs(libero_config_path)
 
 if not os.path.exists(config_file):
     # Create a default config file
 
     default_path_dict = get_default_path_dict()
-    answer = input(
-        "Do you want to specify a custom path for the dataset folder? (Y/N): "
-    ).lower()
-    if answer == "y":
-        # If the user wants to specify a custom storage path, prompt them to enter it
-        custom_dataset_path = input(
-            "Enter the path where you want to store the datasets: "
-        )
-        full_custom_dataset_path = os.path.join(
-            os.path.abspath(os.path.expanduser(custom_dataset_path)), "datasets"
-        )
-        # Check if the custom storage path exists, and create if it doesn't
-
-        print("The full path of the custom storage path you entered is:")
-        print(full_custom_dataset_path)
-        print("Do you want to continue? (Y/N)")
-        confirm_answer = input().lower()
-        if confirm_answer == "y":
-            if not os.path.exists(full_custom_dataset_path):
-                os.makedirs(full_custom_dataset_path)
-            default_path_dict["datasets"] = full_custom_dataset_path
     print("Initializing the default config file...")
     print(f"The following information is stored in the config file: {config_file}")
-    # write all the paths into a yaml file
-    with open(config_file, "w") as f:
-        yaml.dump(default_path_dict, f)
-    for key, value in default_path_dict.items():
-        print(f"{key}: {value}")
+    with FileLock(config_file_lock):
+        with open(config_file, "w") as f:
+            yaml.dump(default_path_dict, f)
+        for key, value in default_path_dict.items():
+            print(f"{key}: {value}")
```
libero/libero/benchmark/__init__.py (1 addition, 1 deletion)

```diff
@@ -161,7 +161,7 @@ def get_task_init_states(self, i):
             self.tasks[i].problem_folder,
             self.tasks[i].init_states_file,
         )
-        init_states = torch.load(init_states_path)
+        init_states = torch.load(init_states_path, weights_only=False)
         return init_states
 
     def set_task_embs(self, task_embs):
```
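Context for the `weights_only` additions in this file and the three below (evaluate.py, metric.py, utils.py): PyTorch 2.6 flipped the default of `torch.load` from `weights_only=False` to `weights_only=True`, which refuses to unpickle arbitrary Python objects. The init-state files and checkpoints loaded here contain more than bare tensors, so they fail under the new default; passing `weights_only=False` restores the old behavior, which is acceptable only because these are trusted, locally generated files. A minimal illustration, with a hypothetical file name:

```python
import torch

# Under PyTorch >= 2.6 the line below raises an UnpicklingError (roughly
# "Weights only load failed") when the file holds pickled Python objects:
# init_states = torch.load("init_states.pt")

# Opting out explicitly restores the pre-2.6 behavior. Only do this for files
# you trust, since unpickling can execute arbitrary code.
init_states = torch.load("init_states.pt", weights_only=False)
```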
libero/lifelong/evaluate.py (1 addition, 1 deletion)

```diff
@@ -251,7 +251,7 @@ def main():
     init_states_path = os.path.join(
         cfg.init_states_folder, task.problem_folder, task.init_states_file
     )
-    init_states = torch.load(init_states_path)
+    init_states = torch.load(init_states_path, weights_only=False)
     indices = np.arange(env_num) % init_states.shape[0]
     init_states_ = init_states[indices]
```
libero/lifelong/metric.py (1 addition, 1 deletion)

```diff
@@ -104,7 +104,7 @@ def evaluate_one_task_success(
         init_states_path = os.path.join(
             cfg.init_states_folder, task.problem_folder, task.init_states_file
         )
-        init_states = torch.load(init_states_path)
+        init_states = torch.load(init_states_path, weights_only=False)
         num_success = 0
         for i in range(eval_loop_num):
             env.reset()
```
libero/lifelong/utils.py (1 addition, 1 deletion)

```diff
@@ -56,7 +56,7 @@ def torch_save_model(model, model_path, cfg=None, previous_masks=None):
 
 
 def torch_load_model(model_path, map_location=None):
-    model_dict = torch.load(model_path, map_location=map_location)
+    model_dict = torch.load(model_path, weights_only=False, map_location=map_location)
     cfg = None
     if "cfg" in model_dict:
         cfg = model_dict["cfg"]
```
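A stricter alternative worth weighing, though not part of this PR: PyTorch 2.4 and later expose `torch.serialization.add_safe_globals`, which keeps `weights_only=True` while allowlisting the specific classes a trusted checkpoint pickles. Since these checkpoints bundle a full `cfg` object, the blanket `weights_only=False` is the simpler fix; the sketch below shows the allowlist route, and the allowlisted class is purely hypothetical (whichever class `cfg` actually is would need to be listed).

```python
import torch
from omegaconf import DictConfig  # hypothetical stand-in for the pickled cfg class

# Allowlist the non-tensor classes the checkpoint is known to contain, then
# load with the safer weights_only=True default still in force.
torch.serialization.add_safe_globals([DictConfig])
model_dict = torch.load("model_checkpoint.pth", weights_only=True)
```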