
Commit 44d68ff

Added paged optimizers.
1 parent ec38ba9 commit 44d68ff

8 files changed (+158, -267 lines)

bitsandbytes/cextension.py

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@
     lib.get_context.restype = ct.c_void_p
     lib.get_cusparse.restype = ct.c_void_p
     lib.cget_managed_ptr.restype = ct.c_void_p
-    lib.cget_stream.restype = ct.c_void_p
     COMPILED_WITH_CUDA = True
 except AttributeError:
     warn("The installed version of bitsandbytes was compiled without GPU support. "

bitsandbytes/functional.py

Lines changed: 29 additions & 4 deletions
@@ -83,6 +83,27 @@ def prod(iterable):
     lib.cadagrad_8bit_blockwise_fp16,
 )

+class GlobalPageManager:
+    _instance = None
+
+    def __init__(self):
+        raise RuntimeError("Call get_instance() instead")
+
+    def initialize(self):
+        self.paged_tensors = []
+
+    @classmethod
+    def get_instance(cls):
+        if cls._instance is None:
+            cls._instance = cls.__new__(cls)
+            cls._instance.initialize()
+        return cls._instance
+
+    def prefetch_all(self, to_cpu=False):
+        for t in self.paged_tensors:
+            prefetch_tensor(t, to_cpu)
+
+

 class CUBLAS_Context:
     _instance = None
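GlobalPageManager is a classic singleton: direct construction is blocked and get_instance() builds the single instance through __new__ plus initialize(). A minimal usage sketch in Python; the explicit append to paged_tensors is only illustrative, since the actual registration point is not part of this hunk:

import torch
import bitsandbytes.functional as F

# Every caller gets the same manager object; F.GlobalPageManager() would
# raise RuntimeError("Call get_instance() instead").
pm = F.GlobalPageManager.get_instance()
assert pm is F.GlobalPageManager.get_instance()

# Illustrative registration of a paged buffer, so that a later
# prefetch_all() call can migrate all registered pages at once.
buf = F.get_paged(1024, 1024, dtype=torch.float32)
pm.paged_tensors.append(buf)

# Move all registered paged tensors toward the GPU (or back to the CPU
# with to_cpu=True) before kernels touch them.
pm.prefetch_all(to_cpu=False)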
@@ -142,7 +163,7 @@ def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0))
     cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
     c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
     new_array = np.ctypeslib.as_array(c_ptr, shape=shape)
-    out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape))
+    out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape)).view(shape)
     out.is_paged = True
     out.page_deviceid = device.index
     return out
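The one-line change appends .view(shape) so that the tensor wrapping the managed buffer carries the requested shape instead of staying flat; is_paged and page_deviceid are attached as before. A sketch of the expected behaviour on a CUDA-enabled build:

import torch
from bitsandbytes.functional import get_paged

# Allocated through lib.cget_managed_ptr, i.e. CUDA managed (unified) memory
# that the driver can page between CPU and GPU on demand.
state = get_paged(4, 256, dtype=torch.float32, device=torch.device('cuda', 0))

assert state.shape == (4, 256)   # with the fix; previously a flat tensor of 1024 elements
assert state.is_paged            # flag consumed by is_on_gpu() below
assert state.page_deviceid == 0  # index of the GPU that owns the pages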
@@ -415,10 +436,14 @@ def is_on_gpu(tensors):
     gpu_ids = set()
     for t in tensors:
         if t is None: continue # NULL pointers are fine
-        on_gpu &= t.device.type == 'cuda'
-        gpu_ids.add(t.device.index)
+        is_paged = getattr(t, 'is_paged', False)
+        on_gpu &= (t.device.type == 'cuda' or is_paged)
+        if not is_paged:
+            gpu_ids.add(t.device.index)
+    if not on_gpu:
+        raise TypeError(f'All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}')
     if len(gpu_ids) > 1:
-        raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:{[(t.shape, t.device) for t in tensors]}')
+        raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}')
     return on_gpu

 def get_ptr(A: Tensor) -> ct.c_void_p:
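With this change, is_on_gpu accepts tensors flagged with is_paged even though their reported device may not be a CUDA device, excludes them from the single-GPU check, and raises a clearer error when a plain CPU tensor slips in. Roughly:

import torch
from bitsandbytes.functional import get_paged, is_on_gpu

a = torch.empty(16, device='cuda:0')
b = get_paged(16, dtype=torch.float32)   # paged: passes via getattr(t, 'is_paged', False)

is_on_gpu([a, b, None])                  # True: None entries are skipped, paged tensors are allowed

c = torch.empty(16)                      # ordinary CPU tensor, not paged
# is_on_gpu([a, c])                      # raises TypeError with the new, clearer message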

bitsandbytes/optim/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -6,8 +6,8 @@
 from bitsandbytes.cextension import COMPILED_WITH_CUDA

 from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
-from .adam import Adam, Adam8bit, Adam32bit
-from .adamw import AdamW, AdamW8bit, AdamW32bit
+from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit
+from .adamw import AdamW, AdamW8bit, AdamW32bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit
 from .lamb import LAMB, LAMB8bit, LAMB32bit
 from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
 from .optimizer import GlobalOptimManager
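With the widened imports, the paged variants are exposed next to the existing classes, e.g.:

# The paged classes mirror the existing Adam/AdamW variants one-to-one.
from bitsandbytes.optim import (
    Adam, Adam8bit, Adam32bit,
    PagedAdam, PagedAdam8bit, PagedAdam32bit,
    AdamW, AdamW8bit, AdamW32bit,
    PagedAdamW, PagedAdamW8bit, PagedAdamW32bit,
)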

bitsandbytes/optim/adam.py

Lines changed: 24 additions & 82 deletions
@@ -14,92 +14,34 @@


 class Adam(Optimizer2State):
-    def __init__(
-        self,
-        params,
-        lr=1e-3,
-        betas=(0.9, 0.999),
-        eps=1e-8,
-        weight_decay=0,
-        amsgrad=False,
-        optim_bits=32,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "adam",
-            params,
-            lr,
-            betas,
-            eps,
-            weight_decay,
-            optim_bits,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
-
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)

 class Adam8bit(Optimizer2State):
-    def __init__(
-        self,
-        params,
-        lr=1e-3,
-        betas=(0.9, 0.999),
-        eps=1e-8,
-        weight_decay=0,
-        amsgrad=False,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "adam",
-            params,
-            lr,
-            betas,
-            eps,
-            weight_decay,
-            8,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
-
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)

 class Adam32bit(Optimizer2State):
-    def __init__(
-        self,
-        params,
-        lr=1e-3,
-        betas=(0.9, 0.999),
-        eps=1e-8,
-        weight_decay=0,
-        amsgrad=False,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "adam",
-            params,
-            lr,
-            betas,
-            eps,
-            weight_decay,
-            32,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
-
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
+
+class PagedAdam(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+
+class PagedAdam8bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+
+class PagedAdam32bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)

 class AnalysisAdam(torch.optim.Optimizer):
     """Adam that performs 8-bit vs 32-bit error analysis.

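Each compact class differs only in the optim_bits value and is_paged flag it forwards to Optimizer2State: the Adam* classes pass the caller's is_paged, while the PagedAdam* classes ignore that argument and always pass is_paged=True (the 8bit/32bit variants additionally hard-code 8 or 32 bits). So the following two constructions should be equivalent; the tiny model is only for illustration:

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096).cuda()

# Equivalent: PagedAdam8bit simply forces is_paged=True on Optimizer2State.
opt_a = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, is_paged=True)
opt_b = bnb.optim.PagedAdam8bit(model.parameters(), lr=1e-3)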
bitsandbytes/optim/adamw.py

Lines changed: 27 additions & 81 deletions
@@ -5,89 +5,35 @@
 from bitsandbytes.optim.optimizer import Optimizer2State


-class AdamW(Optimizer2State):
-    def __init__(
-        self,
-        params,
-        lr=1e-3,
-        betas=(0.9, 0.999),
-        eps=1e-8,
-        weight_decay=1e-2,
-        amsgrad=False,
-        optim_bits=32,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "adam",
-            params,
-            lr,
-            betas,
-            eps,
-            weight_decay,
-            optim_bits,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )

+class AdamW(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )

 class AdamW8bit(Optimizer2State):
-    def __init__(
-        self,
-        params,
-        lr=1e-3,
-        betas=(0.9, 0.999),
-        eps=1e-8,
-        weight_decay=1e-2,
-        amsgrad=False,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "adam",
-            params,
-            lr,
-            betas,
-            eps,
-            weight_decay,
-            8,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
-
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )

 class AdamW32bit(Optimizer2State):
-    def __init__(
-        self,
-        params,
-        lr=1e-3,
-        betas=(0.9, 0.999),
-        eps=1e-8,
-        weight_decay=1e-2,
-        amsgrad=False,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "adam",
-            params,
-            lr,
-            betas,
-            eps,
-            weight_decay,
-            32,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
+
+
+class PagedAdamW(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+
+class PagedAdamW8bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+
+class PagedAdamW32bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+
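Taken together, a paged optimizer drops into an ordinary training loop: its state tensors live in paged (unified) memory, so they can be evicted to CPU RAM when GPU memory is tight. A minimal sketch, assuming a CUDA-enabled build of this commit; the model and data are only for illustration:

import torch
import bitsandbytes as bnb

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024),
).cuda()

# 8-bit AdamW whose optimizer state is allocated as paged tensors
# (the class itself forwards is_paged=True to Optimizer2State).
optimizer = bnb.optim.PagedAdamW8bit(model.parameters(), lr=1e-4, weight_decay=1e-2)

x = torch.randn(8, 1024, device='cuda')
for step in range(10):
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()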
