Skip to content

Commit c8fa61a

Browse files
committed
fix code quality
1 parent 19db70e commit c8fa61a

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

examples/reinforcement_learning/diffusion_policy.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import numpy as np
2+
import numpy.core.multiarray as multiarray
23
import torch
34
import torch.nn as nn
45
from huggingface_hub import hf_hub_download
5-
from diffusers import DDPMScheduler, UNet1DModel
66
from torch.serialization import add_safe_globals
7-
import numpy.core.multiarray as multiarray
7+
8+
from diffusers import DDPMScheduler, UNet1DModel
9+
10+
811
add_safe_globals([multiarray._reconstruct, np.ndarray, np.dtype, np.dtype(np.float32).type, np.dtype(np.float64).type, np.dtype(np.int32).type, np.dtype(np.int64).type, type(np.dtype(np.float32)), type(np.dtype(np.float64)), type(np.dtype(np.int32)), type(np.dtype(np.int64))])
912

1013
"""

0 commit comments

Comments
 (0)