
Commit babfb8a

[MPS] call contiguous after permute (#1411)
* call contiguous after permute. Fixes for MPS device.
* Fix MPS UserWarning
* make style
* Revert "Fix MPS UserWarning". This reverts commit b46c328.
1 parent 35099b2 commit babfb8a

1 file changed: 6 additions, 2 deletions
src/diffusers/models/attention.py

Lines changed: 6 additions & 2 deletions
@@ -221,11 +221,15 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, retu
         # 3. Output
         if self.is_input_continuous:
             if not self.use_linear_projection:
-                hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
+                hidden_states = (
+                    hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+                )
                 hidden_states = self.proj_out(hidden_states)
             else:
                 hidden_states = self.proj_out(hidden_states)
-                hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
+                hidden_states = (
+                    hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+                )
 
             output = hidden_states + residual
         elif self.is_input_vectorized:
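The pattern in both branches is the same: after the reshape + permute that turns hidden_states back into NCHW layout, .contiguous() is called before the result is fed to (or returned around) proj_out, which avoids the MPS issue the commit title describes. Below is a minimal, self-contained sketch of that pattern. The device check, tensor sizes, and the standalone proj_out layer are illustrative assumptions, not the diffusers code itself (the upstream code also names the width variable `weight`).

import torch

# Minimal sketch of the pattern this commit adopts (not the diffusers code itself).
# On the MPS backend, feeding the non-contiguous tensor produced by permute()
# into the projection has been reported to misbehave, so .contiguous() is called first.
device = "mps" if torch.backends.mps.is_available() else "cpu"

batch, height, width, inner_dim = 1, 8, 8, 32
proj_out = torch.nn.Conv2d(inner_dim, inner_dim, kernel_size=1).to(device)

# Shape (batch, height * width, inner_dim), as produced by the transformer blocks.
hidden_states = torch.randn(batch, height * width, inner_dim, device=device)

# reshape + permute alone leaves the tensor non-contiguous; the added
# .contiguous() gives the subsequent Conv2d a contiguous NCHW tensor.
hidden_states = (
    hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
)
hidden_states = proj_out(hidden_states)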
