Commit a3f55ce

Fixed optim update error with non-contiguous grads/params (#1187)
* Fixed optim update error with non-contiguous grads
* Fix formatting

Thanks @Edenzzzz for this contribution!

Co-authored-by: Titus von Koeller <[email protected]>
Parent commit: 5212a0f

File tree: 1 file changed (+8 -0 lines)

bitsandbytes/optim/optimizer.py

8 additions, 0 deletions
@@ -474,6 +474,10 @@ def init_state(self, group, p, gindex, pindex):

     @torch.no_grad()
     def update_step(self, group, p, gindex, pindex):
+        # avoid update error from non-contiguous memory layout
+        p.data = p.data.contiguous()
+        p.grad = p.grad.contiguous()
+
         state = self.state[p]
         grad = p.grad

@@ -685,6 +689,10 @@ def init_state(self, group, p, gindex, pindex):

     @torch.no_grad()
     def update_step(self, group, p, gindex, pindex):
+        # avoid update error from non-contiguous memory layout
+        p.data = p.data.contiguous()
+        p.grad = p.grad.contiguous()
+
         state = self.state[p]
         grad = p.grad
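For context, a minimal sketch (not part of the commit) of how a non-contiguous tensor arises in PyTorch and how `.contiguous()` repairs it. Operations such as `transpose` return strided views over the same storage, and optimizer kernels that index into a flat buffer can misbehave on such views; copying with `.contiguous()` restores a dense row-major layout:

```python
import torch

# A transpose returns a view with swapped strides, not a copy,
# so the underlying memory is no longer laid out contiguously.
w = torch.ones(4, 2).t()
print(w.is_contiguous())   # False

# .contiguous() materializes a fresh, contiguous copy of the data,
# which is the layout the optimizer's update kernels expect.
w = w.contiguous()
print(w.is_contiguous())   # True
```

The patch applies this normalization to both `p.data` and `p.grad` at the top of `update_step`, so downstream kernels always see contiguous buffers regardless of how the model produced its parameters and gradients.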
