@@ -238,7 +238,7 @@ def __init__(
238238
239239        # setable values 
240240        self .num_inference_steps  =  None 
241-         self .timesteps  =  torch .from_numpy (np .arange (0 , num_train_timesteps )[::- 1 ].copy ().astype (np .int32 ))
241+         self .timesteps  =  torch .from_numpy (np .arange (0 , num_train_timesteps )[::- 1 ].copy ().astype (np .int64 ))
242242
243243    # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input 
244244    def  scale_model_input (self , sample : torch .Tensor , timestep : Optional [int ] =  None ) ->  torch .Tensor :
@@ -341,19 +341,19 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
341341                np .linspace (0 , self .config .num_train_timesteps  -  1 , num_inference_steps )
342342                .round ()[::- 1 ]
343343                .copy ()
344-                 .astype (np .int32 )
344+                 .astype (np .int64 )
345345            )
346346        elif  self .config .timestep_spacing  ==  "leading" :
347347            step_ratio  =  self .config .num_train_timesteps  //  self .num_inference_steps 
348348            # creates integer timesteps by multiplying by ratio 
349349            # casting to int to avoid issues when num_inference_step is power of 3 
350-             timesteps  =  (np .arange (0 , num_inference_steps ) *  step_ratio ).round ()[::- 1 ].copy ().astype (np .int32 )
350+             timesteps  =  (np .arange (0 , num_inference_steps ) *  step_ratio ).round ()[::- 1 ].copy ().astype (np .int64 )
351351            timesteps  +=  self .config .steps_offset 
352352        elif  self .config .timestep_spacing  ==  "trailing" :
353353            step_ratio  =  self .config .num_train_timesteps  /  self .num_inference_steps 
354354            # creates integer timesteps by multiplying by ratio 
355355            # casting to int to avoid issues when num_inference_step is power of 3 
356-             timesteps  =  np .round (np .arange (self .config .num_train_timesteps , 0 , - step_ratio )).astype (np .int32 )
356+             timesteps  =  np .round (np .arange (self .config .num_train_timesteps , 0 , - step_ratio )).astype (np .int64 )
357357            timesteps  -=  1 
358358        else :
359359            raise  ValueError (
0 commit comments