
Commit 2db8b8e

forward XLA precision
1 parent 61f01c6 commit 2db8b8e


3 files changed: +11, -30 lines changed


src/lightning/fabric/plugins/precision/bitsandbytes.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ def __init__(
             BitsandbytesPrecision as EnterpriseBitsandbytesPrecision,
         )
 
-        self.bitsandbytes_impl = EnterpriseBitsandbytesPrecision(mode, dtype, ignore_modules)
+        self.bitsandbytes_impl = EnterpriseBitsandbytesPrecision(mode=mode, dtype=dtype, ignore_modules=ignore_modules)
 
     @override
     def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:

src/lightning/fabric/plugins/precision/deepspeed.py

Lines changed: 2 additions & 2 deletions
@@ -45,10 +45,10 @@ def __init__(self, precision: _PRECISION_INPUT) -> None:
         super().__init__()
         _raise_enterprise_not_available()
         from pytorch_lightning_enterprise.fabric.plugins.precision.deepspeed import (
-            DeepSpeedPrecision as EnterpriseDeepSpeedPrecision,
+            DeepSpeedPrecisionFabric as EnterpriseDeepSpeedPrecision,
         )
 
-        self.deepspeed_impl = EnterpriseDeepSpeedPrecision(precision)
+        self.deepspeed_impl = EnterpriseDeepSpeedPrecision(precision=precision)
 
     @override
     def convert_module(self, module: Module) -> Module:
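
The bitsandbytes and deepspeed hunks above make the same small change: the wrapper now forwards its constructor arguments to the enterprise implementation by keyword instead of by position (the DeepSpeed hunk additionally imports the renamed DeepSpeedPrecisionFabric class under the old alias). A minimal sketch of why keyword forwarding is the more defensive idiom, using a hypothetical stand-in class rather than the real pytorch_lightning_enterprise API:

# Hypothetical stand-in for an enterprise precision plugin; the real classes
# live in pytorch_lightning_enterprise and may have different signatures.
class EnterprisePrecisionStub:
    def __init__(self, *, precision: str = "32-true") -> None:
        # Keyword-only parameter: positional calls raise a TypeError.
        self.precision = precision

# Keyword forwarding keeps working even if the upstream signature is
# reordered or becomes keyword-only; positional forwarding would not.
impl = EnterprisePrecisionStub(precision="bf16-true")  # ok
# impl = EnterprisePrecisionStub("bf16-true")          # TypeError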

src/lightning/fabric/plugins/precision/xla.py

Lines changed: 8 additions & 27 deletions
@@ -11,14 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from typing import Any, Literal
 
-import torch
-from typing_extensions import get_args, override
+from typing_extensions import override
 
-from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
 from lightning.fabric.plugins.precision.precision import Precision
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
 from lightning.fabric.utilities.types import Optimizable
 
 _PRECISION_INPUT = Literal["32-true", "16-true", "bf16-true"]

@@ -37,37 +35,20 @@ class XLAPrecision(Precision):
     """
 
     def __init__(self, precision: _PRECISION_INPUT) -> None:
-        if not _XLA_AVAILABLE:
-            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
-        supported_precision = get_args(_PRECISION_INPUT)
-        if precision not in supported_precision:
-            raise ValueError(
-                f"`precision={precision!r})` is not supported in XLA."
-                f" `precision` must be one of: {supported_precision}."
-            )
-        self.precision = precision
+        super().__init__()
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.fabric.plugins.precision.xla import XLAPrecision as EnterpriseXLAPrecision
 
-        if precision == "16-true":
-            os.environ["XLA_USE_F16"] = "1"
-            self._desired_dtype = torch.float16
-        elif precision == "bf16-true":
-            os.environ["XLA_USE_BF16"] = "1"
-            self._desired_dtype = torch.bfloat16
-        else:
-            self._desired_dtype = torch.float32
+        self.xla_impl = EnterpriseXLAPrecision(precision=precision)
 
     @override
     def optimizer_step(
         self,
         optimizer: Optimizable,
         **kwargs: Any,
     ) -> Any:
-        import torch_xla.core.xla_model as xm
-
-        # you always want to `xm.mark_step()` after `optimizer.step` for better performance, so we set `barrier=True`
-        return xm.optimizer_step(optimizer, optimizer_args=kwargs, barrier=True)
+        return self.xla_impl.optimizer_step(optimizer, **kwargs)
 
     @override
     def teardown(self) -> None:
-        os.environ.pop("XLA_USE_BF16", None)
-        os.environ.pop("XLA_USE_F16", None)
+        return self.xla_impl.teardown()
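
Taken together, the XLA hunks drop the local implementation (the XLA_USE_F16 / XLA_USE_BF16 environment variables, the desired-dtype bookkeeping, and the direct xm.optimizer_step call) and turn XLAPrecision into the same thin forwarding wrapper already used by the bitsandbytes and deepspeed plugins. A minimal sketch of the resulting delegation pattern, with a hypothetical stub standing in for the enterprise XLAPrecision implementation:

from typing import Any


# Hypothetical stub standing in for the enterprise implementation imported from
# pytorch_lightning_enterprise.fabric.plugins.precision.xla in the diff above.
class _EnterpriseXLAPrecisionStub:
    def __init__(self, *, precision: str) -> None:
        self.precision = precision

    def optimizer_step(self, optimizer: Any, **kwargs: Any) -> Any:
        # The real plugin wraps the XLA optimizer step; here we just delegate.
        return optimizer.step(**kwargs)

    def teardown(self) -> None:
        # The real plugin restores any global state it touched.
        pass


class XLAPrecisionWrapper:
    """Thin facade: construct the enterprise plugin once, forward every call to it."""

    def __init__(self, precision: str) -> None:
        self.xla_impl = _EnterpriseXLAPrecisionStub(precision=precision)

    def optimizer_step(self, optimizer: Any, **kwargs: Any) -> Any:
        return self.xla_impl.optimizer_step(optimizer, **kwargs)

    def teardown(self) -> None:
        return self.xla_impl.teardown()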
