Skip to content

Commit 19db70e

Browse files
committed
suggestions + safe weights loading
1 parent 9a5cfe7 commit 19db70e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/reinforcement_learning/diffusion_policy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,7 @@ def __init__(self, state_dim=5, device="cpu"):
9595
)
9696

9797
# load pre-trained weights from HuggingFace
98-
checkpoint = torch.load(hf_hub_download("dorsar/diffusion_policy", "push_tblock.pt"), map_location=device)
99-
98+
checkpoint = torch.load(hf_hub_download("dorsar/diffusion_policy", "push_tblock.pt"), weights_only=True, map_location=device)
10099
self.model.load_state_dict(checkpoint['model_state_dict'])
101100
self.obs_encoder.load_state_dict(checkpoint['encoder_state_dict'])
102101
self.obs_projection.load_state_dict(checkpoint['projection_state_dict'])
@@ -173,3 +172,4 @@ def predict(self, observation):
173172
print("\nPredicted trajectory:")
174173
for i, (x, y) in enumerate(action[0]):
175174
print(f"Step {i:2d}: x={x:6.1f}, y={y:6.1f}")
175+

0 commit comments

Comments
 (0)