Skip to content

Commit 811c049

Browse files
committed
fix: torch 2.6 weights_only load
Signed-off-by: Hao Lin <[email protected]>
1 parent 8f1084e commit 811c049

File tree

4 files changed

+4
-4
lines changed

4 files changed

+4
-4
lines changed

libero/libero/benchmark/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def get_task_init_states(self, i):
161161
self.tasks[i].problem_folder,
162162
self.tasks[i].init_states_file,
163163
)
164-
init_states = torch.load(init_states_path)
164+
init_states = torch.load(init_states_path, weights_only=False)
165165
return init_states
166166

167167
def set_task_embs(self, task_embs):

libero/lifelong/evaluate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def main():
251251
init_states_path = os.path.join(
252252
cfg.init_states_folder, task.problem_folder, task.init_states_file
253253
)
254-
init_states = torch.load(init_states_path)
254+
init_states = torch.load(init_states_path, weights_only=False)
255255
indices = np.arange(env_num) % init_states.shape[0]
256256
init_states_ = init_states[indices]
257257

libero/lifelong/metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def evaluate_one_task_success(
104104
init_states_path = os.path.join(
105105
cfg.init_states_folder, task.problem_folder, task.init_states_file
106106
)
107-
init_states = torch.load(init_states_path)
107+
init_states = torch.load(init_states_path, weights_only=False)
108108
num_success = 0
109109
for i in range(eval_loop_num):
110110
env.reset()

libero/lifelong/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def torch_save_model(model, model_path, cfg=None, previous_masks=None):
5656

5757

5858
def torch_load_model(model_path, map_location=None):
59-
model_dict = torch.load(model_path, map_location=map_location)
59+
model_dict = torch.load(model_path, weights_only=False, map_location=map_location)
6060
cfg = None
6161
if "cfg" in model_dict:
6262
cfg = model_dict["cfg"]

0 commit comments

Comments
 (0)