
Commit 5212a0f

Edenzzzz's fix for min_8bit_size functionality in Optimizer base classes (#1286)
* fix min_8bit_size invalid bug
* Apply same fix to other optimizer base class

Co-authored-by: Edenzzzz <[email protected]>
1 parent 0bdd57c commit 5212a0f

File tree

bitsandbytes/optim/optimizer.py

1 file changed: +2 -2 lines changed


bitsandbytes/optim/optimizer.py

Lines changed: 2 additions & 2 deletions
@@ -437,7 +437,7 @@ def init_state(self, group, p, gindex, pindex):
         state = self.state[p]
         state["step"] = 0
 
-        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+        if dtype == torch.float32:
             state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
             state["state2"] = self.get_state_buffer(p, dtype=torch.float32)
         elif dtype == torch.uint8:
@@ -656,7 +656,7 @@ def init_state(self, group, p, gindex, pindex):
         state = self.state[p]
         state["step"] = 0
 
-        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+        if dtype == torch.float32:
             state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
         elif dtype == torch.uint8:
             if state["step"] == 0:
