
Commit e211a93

carmocca authored and lantiga committed
Bitsandbytes docs improvements (#18903)
(cherry picked from commit ad93f64)
1 parent 6b87dbc commit e211a93

File tree

docs/source-fabric/api/fabric_args.rst
docs/source-fabric/fundamentals/precision.rst
docs/source-fabric/glossary/index.rst
docs/source-pytorch/common/precision_intermediate.rst

4 files changed: +37 -8 lines changed

docs/source-fabric/api/fabric_args.rst

Lines changed: 11 additions & 0 deletions
@@ -144,6 +144,17 @@ Automatic mixed precision settings are denoted by a ``"-mixed"`` suffix, while "
     fabric = Fabric(precision="64-true", devices=1)
 
 
+Precision settings can also be enabled via the plugins argument (see section below on plugins).
+An example is the weights quantization plugin Bitsandbytes for 4-bit and 8-bit:
+
+.. code-block:: python
+
+    from lightning.fabric.plugins import BitsandbytesPrecision
+
+    precision = BitsandbytesPrecision(mode="nf4-dq", dtype=torch.bfloat16)
+    fabric = Fabric(plugins=precision)
+
+
 plugins
 =======
 
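For orientation beyond the diff itself, a minimal sketch of how the snippet added above runs end to end. The throwaway torch.nn.Sequential model and the fabric.launch()/fabric.setup() calls are illustrative assumptions, not part of the changed docs.

.. code-block:: python

    import torch
    from lightning.fabric import Fabric
    from lightning.fabric.plugins import BitsandbytesPrecision

    # 4-bit NF4 storage with double quantization; compute runs in bfloat16.
    # Assumes a CUDA device and the `bitsandbytes` package are available.
    precision = BitsandbytesPrecision(mode="nf4-dq", dtype=torch.bfloat16)
    fabric = Fabric(plugins=precision, devices=1)
    fabric.launch()

    # Throwaway model for illustration; any module containing torch.nn.Linear layers works.
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8))
    model = fabric.setup(model)  # Linear layers are swapped for their BNB counterparts here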

docs/source-fabric/fundamentals/precision.rst

Lines changed: 6 additions & 4 deletions
@@ -214,7 +214,7 @@ See also: :doc:`../advanced/model_init`
 Quantization via Bitsandbytes
 *****************************
 
-`bitsandbytes <https://github.com/TimDettmers/bitsandbytes>`__ (BNB) is a library that supports quantizing Linear weights.
+`bitsandbytes <https://github.com/TimDettmers/bitsandbytes>`__ (BNB) is a library that supports quantizing :class:`torch.nn.Linear` weights.
 
 Both 4-bit (`paper reference <https://arxiv.org/abs/2305.14314v1>`__) and 8-bit (`paper reference <https://arxiv.org/abs/2110.02861>`__) quantization is supported.
 Specifically, we support the following modes:
@@ -228,20 +228,22 @@ Specifically, we support the following modes:
 
 While these techniques store weights in 4 or 8 bit, the computation still happens in 16 or 32-bit (float16, bfloat16, float32).
 This is configurable via the dtype argument in the plugin.
+If your model weights can fit on a single device with 16 bit precision, it's recommended that this plugin is not used as it will slow down training.
 
 Quantizing the model will dramatically reduce the weight's memory requirements but may have a negative impact on the model's performance or runtime.
 
-The :class:`~lightning.fabric.plugins.precision.bitsandbytes.BitsandbytesPrecision` a
+The :class:`~lightning.fabric.plugins.precision.bitsandbytes.BitsandbytesPrecision` automatically replaces the :class:`torch.nn.Linear` layers in your model with their BNB alternatives.
+
 .. code-block:: python
 
     from lightning.fabric.plugins import BitsandbytesPrecision
 
     # this will pick out the compute dtype automatically, by default `bfloat16`
-    precision = BitsandbytesPrecision("nf4-dq")
+    precision = BitsandbytesPrecision(mode="nf4-dq")
     fabric = Fabric(plugins=precision)
 
     # Customize the dtype, or ignore some modules
-    precision = BitsandbytesPrecision("int8-training", dtype=torch.float16, ignore_modules={"lm_head"})
+    precision = BitsandbytesPrecision(mode="int8-training", dtype=torch.float16, ignore_modules={"lm_head"})
     fabric = Fabric(plugins=precision)
 
     model = MyModel()
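Since the text above describes the automatic replacement of torch.nn.Linear layers and the ignore_modules escape hatch, here is a hedged sketch of that behaviour with a toy model; the TinyLM class and its lm_head attribute are assumptions made up for illustration, not part of the documented change.

.. code-block:: python

    import torch
    from lightning.fabric import Fabric
    from lightning.fabric.plugins import BitsandbytesPrecision


    class TinyLM(torch.nn.Module):
        # Hypothetical stand-in for MyModel(): a small body plus an `lm_head` projection.
        def __init__(self):
            super().__init__()
            self.body = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.GELU())
            self.lm_head = torch.nn.Linear(256, 1000)

        def forward(self, x):
            return self.lm_head(self.body(x))


    # Store the body's Linear weights in 8-bit and compute in float16, but leave `lm_head` untouched.
    precision = BitsandbytesPrecision(mode="int8-training", dtype=torch.float16, ignore_modules={"lm_head"})
    fabric = Fabric(plugins=precision, devices=1)
    model = fabric.setup(TinyLM())  # only Linear layers outside `ignore_modules` are replaced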

docs/source-fabric/glossary/index.rst

Lines changed: 16 additions & 1 deletion
@@ -28,6 +28,11 @@ Glossary
    :button_link: ../advanced/distributed_communication.html
    :col_css: col-md-4
 
+.. displayitem::
+   :header: Bfloat16
+   :button_link: ../fundamentals/precision.html
+   :col_css: col-md-4
+
 .. displayitem::
    :header: Broadcast
    :button_link: ../advanced/distributed_communication.html
@@ -89,7 +94,7 @@ Glossary
    :col_css: col-md-4
 
 .. displayitem::
-   :header: Jypyter
+   :header: Jupyter
    :button_link: ../launch/notebooks.html
    :col_css: col-md-4
 
@@ -148,6 +153,11 @@ Glossary
    :button_link: ../fundamentals/precision.html
    :col_css: col-md-4
 
+.. displayitem::
+   :header: Quantization
+   :button_link: ../fundamentals/precision.html
+   :col_css: col-md-4
+
 .. displayitem::
    :header: Reduce
    :button_link: ../advanced/distributed_communication.html
@@ -183,6 +193,11 @@ Glossary
    :button_link: ../guide/trainer_template.html
    :col_css: col-md-4
 
+.. displayitem::
+   :header: 16-bit, 8-bit, 4-bit
+   :button_link: ../fundamentals/precision.html
+   :col_css: col-md-4
+
 
 .. raw:: html

docs/source-pytorch/common/precision_intermediate.rst

Lines changed: 4 additions & 3 deletions
@@ -165,7 +165,7 @@ Under the hood, we use `transformer_engine.pytorch.fp8_autocast <https://docs.nv
 Quantization via Bitsandbytes
 *****************************
 
-`bitsandbytes <https://github.com/TimDettmers/bitsandbytes>`__ (BNB) is a library that supports quantizing Linear weights.
+`bitsandbytes <https://github.com/TimDettmers/bitsandbytes>`__ (BNB) is a library that supports quantizing :class:`torch.nn.Linear` weights.
 
 Both 4-bit (`paper reference <https://arxiv.org/abs/2305.14314v1>`__) and 8-bit (`paper reference <https://arxiv.org/abs/2110.02861>`__) quantization is supported.
 Specifically, we support the following modes:
@@ -179,6 +179,7 @@ Specifically, we support the following modes:
 
 While these techniques store weights in 4 or 8 bit, the computation still happens in 16 or 32-bit (float16, bfloat16, float32).
 This is configurable via the dtype argument in the plugin.
+If your model weights can fit on a single device with 16 bit precision, it's recommended that this plugin is not used as it will slow down training.
 
 Quantizing the model will dramatically reduce the weight's memory requirements but may have a negative impact on the model's performance or runtime.
 
@@ -189,11 +190,11 @@ The :class:`~lightning.pytorch.plugins.precision.bitsandbytes.BitsandbytesPrecis
     from lightning.pytorch.plugins import BitsandbytesPrecisionPlugin
 
     # this will pick out the compute dtype automatically, by default `bfloat16`
-    precision = BitsandbytesPrecisionPlugin("nf4-dq")
+    precision = BitsandbytesPrecisionPlugin(mode="nf4-dq")
     trainer = Trainer(plugins=precision)
 
     # Customize the dtype, or skip some modules
-    precision = BitsandbytesPrecisionPlugin("int8-training", dtype=torch.float16, ignore_modules={"lm_head"})
+    precision = BitsandbytesPrecisionPlugin(mode="int8-training", dtype=torch.float16, ignore_modules={"lm_head"})
     trainer = Trainer(plugins=precision)
 
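To ground the Trainer-side snippet, a hedged sketch of how the plugin would be wired into a minimal LightningModule run; the LitModel class, max_steps=1, and the commented-out fit call are illustrative assumptions, not part of the documented change.

.. code-block:: python

    import torch
    from lightning.pytorch import LightningModule, Trainer
    from lightning.pytorch.plugins import BitsandbytesPrecisionPlugin


    class LitModel(LightningModule):
        # Hypothetical module, only here to show where the plugin hooks in.
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            out = self.layer(batch)
            return torch.nn.functional.mse_loss(out, torch.zeros_like(out))

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)


    # The quantized Linear layers are created when the Trainer sets the model up.
    precision = BitsandbytesPrecisionPlugin(mode="nf4-dq")
    trainer = Trainer(plugins=precision, max_steps=1)
    # trainer.fit(LitModel(), train_dataloaders=...)  # dataloader omitted in this sketch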
