Skip to content

Commit 2bad2fc

Browse files
DorsaRohsayakpaul
andauthored
Update examples/reinforcement_learning/diffusion_policy.py
Co-authored-by: Sayak Paul <[email protected]>
1 parent 5bff039 commit 2bad2fc

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

examples/reinforcement_learning/diffusion_policy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(self, state_dim=5, device="cuda" if torch.cuda.is_available() else
9494
)
9595

9696
# load pre-trained weights from HuggingFace
97-
checkpoint = torch.load(hf_hub_download("dorsar/diffusion_policy", "push_tblock.pt"), map_location=device)
97+
checkpoint = torch.load(hf_hub_download("dorsar/diffusion_policy", "push_tblock.pt"), weights_only=True, map_location=device)
9898

9999
self.model.load_state_dict(checkpoint['model_state_dict'])
100100
self.obs_encoder.load_state_dict(checkpoint['encoder_state_dict'])

0 commit comments

Comments
 (0)