 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from typing import Any, Literal

-import torch
-from typing_extensions import get_args, override
+from typing_extensions import override

-from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
 from lightning.fabric.plugins.precision.precision import Precision
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
 from lightning.fabric.utilities.types import Optimizable

 _PRECISION_INPUT = Literal["32-true", "16-true", "bf16-true"]
@@ -37,37 +35,20 @@ class XLAPrecision(Precision):
     """

     def __init__(self, precision: _PRECISION_INPUT) -> None:
-        if not _XLA_AVAILABLE:
-            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
-        supported_precision = get_args(_PRECISION_INPUT)
-        if precision not in supported_precision:
-            raise ValueError(
-                f"`precision={precision!r})` is not supported in XLA."
-                f" `precision` must be one of: {supported_precision}."
-            )
-        self.precision = precision
+        super().__init__()
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.fabric.plugins.precision.xla import XLAPrecision as EnterpriseXLAPrecision

-        if precision == "16-true":
-            os.environ["XLA_USE_F16"] = "1"
-            self._desired_dtype = torch.float16
-        elif precision == "bf16-true":
-            os.environ["XLA_USE_BF16"] = "1"
-            self._desired_dtype = torch.bfloat16
-        else:
-            self._desired_dtype = torch.float32
+        self.xla_impl = EnterpriseXLAPrecision(precision=precision)

     @override
     def optimizer_step(
         self,
         optimizer: Optimizable,
         **kwargs: Any,
     ) -> Any:
-        import torch_xla.core.xla_model as xm
-
-        # you always want to `xm.mark_step()` after `optimizer.step` for better performance, so we set `barrier=True`
-        return xm.optimizer_step(optimizer, optimizer_args=kwargs, barrier=True)
+        return self.xla_impl.optimizer_step(optimizer, **kwargs)

     @override
     def teardown(self) -> None:
-        os.environ.pop("XLA_USE_BF16", None)
-        os.environ.pop("XLA_USE_F16", None)
+        return self.xla_impl.teardown()
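A minimal usage sketch (not part of the diff above), assuming the `pytorch_lightning_enterprise` package is installed; otherwise `_raise_enterprise_not_available()` raises during `__init__`. The import path follows the module edited here.

```python
# Sketch only: construct the refactored wrapper and note where calls go.
from lightning.fabric.plugins.precision.xla import XLAPrecision

plugin = XLAPrecision(precision="bf16-true")  # builds the wrapped EnterpriseXLAPrecision

# Both overridden methods delegate to the enterprise implementation:
#   plugin.optimizer_step(optimizer)  ->  plugin.xla_impl.optimizer_step(optimizer)
#   plugin.teardown()                 ->  plugin.xla_impl.teardown()
```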