Skip to content

Commit a7b8ef2

Browse files
committed
suggestions + safe globals for weights_only=True
1 parent be98e76 commit a7b8ef2

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

examples/reinforcement_learning/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
## Diffusion-based Policy Learning for RL
33

4-
`diffusion_policy` implements <a href="https://diffusion-policy.cs.columbia.edu/">Diffusion Policy</a>, a diffusion model that predicts robot action sequences in reinforcement learning tasks.
4+
`diffusion_policy` implements [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/), a diffusion model that predicts robot action sequences in reinforcement learning tasks.
55

66
This example implements a robot control model for pushing a T-shaped block into a target area. The model takes in current state observations as input, and outputs a trajectory of subsequent steps to follow.
77

examples/reinforcement_learning/diffusion_policy.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
import torch
33
import torch.nn as nn
44
from huggingface_hub import hf_hub_download
5-
65
from diffusers import DDPMScheduler, UNet1DModel
7-
6+
from torch.serialization import add_safe_globals
7+
import numpy.core.multiarray as multiarray
8+
add_safe_globals([multiarray._reconstruct, np.ndarray, np.dtype, np.dtype(np.float32).type, np.dtype(np.float64).type, np.dtype(np.int32).type, np.dtype(np.int64).type, type(np.dtype(np.float32)), type(np.dtype(np.float64)), type(np.dtype(np.int32)), type(np.dtype(np.int64))])
89

910
"""
1011
An example of using HuggingFace's diffusers library for diffusion policy,
@@ -65,7 +66,7 @@ class DiffusionPolicy:
6566
The model expects observations in pixel coordinates (0-512 range) and block angle in radians.
6667
It generates trajectories as sequences of (x,y) coordinates also in the 0-512 range.
6768
"""
68-
def __init__(self, state_dim=5, device="cuda" if torch.cuda.is_available() else "cpu"):
69+
def __init__(self, state_dim=5, device="cpu"):
6970
self.device = device
7071

7172
# define valid ranges for inputs/outputs
@@ -94,8 +95,7 @@ def __init__(self, state_dim=5, device="cuda" if torch.cuda.is_available() else
9495
)
9596

9697
# load pre-trained weights from HuggingFace
97-
checkpoint = torch.load(hf_hub_download("dorsar/diffusion_policy", "push_tblock.pt"), map_location=device)
98-
98+
checkpoint = torch.load(hf_hub_download("dorsar/diffusion_policy", "push_tblock.pt"), weights_only=True, map_location=device)
9999
self.model.load_state_dict(checkpoint['model_state_dict'])
100100
self.obs_encoder.load_state_dict(checkpoint['encoder_state_dict'])
101101
self.obs_projection.load_state_dict(checkpoint['projection_state_dict'])

0 commit comments

Comments
 (0)