@@ -167,8 +167,8 @@ def forward(
167167 spatial_shape = hidden_states .shape [- 2 :]
168168 spatial_noise = torch .randn (
169169 spatial_shape , generator = generator , device = hidden_states .device , dtype = hidden_states .dtype
170- )
171- hidden_states = hidden_states + (spatial_noise * self .per_channel_scale1 )[None , :, None , :, : ]
170+ )[ None ]
171+ hidden_states = hidden_states + (spatial_noise * self .per_channel_scale1 )[None , :, None , ... ]
172172
173173 hidden_states = self .norm2 (hidden_states .movedim (1 , - 1 )).movedim (- 1 , 1 )
174174
@@ -183,8 +183,8 @@ def forward(
183183 spatial_shape = hidden_states .shape [- 2 :]
184184 spatial_noise = torch .randn (
185185 spatial_shape , generator = generator , device = hidden_states .device , dtype = hidden_states .dtype
186- )
187- hidden_states = hidden_states + (spatial_noise * self .per_channel_scale2 )[None , :, None , :, : ]
186+ )[ None ]
187+ hidden_states = hidden_states + (spatial_noise * self .per_channel_scale2 )[None , :, None , ... ]
188188
189189 if self .norm3 is not None :
190190 inputs = self .norm3 (inputs .movedim (1 , - 1 )).movedim (- 1 , 1 )
0 commit comments