Skip to content

Commit 65db8ac

Browse files
authored
Fix TE internal imports (#3852)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
1 parent a72834c commit 65db8ac

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/accelerate/utils/imports.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ def is_transformer_engine_available():


 def is_transformer_engine_mxfp8_available():
     if _is_package_available("transformer_engine", "transformer-engine"):
-        import transformer_engine.pytorch as te
+        from transformer_engine.pytorch.fp8 import check_mxfp8_support

-        return te.fp8.check_mxfp8_support()[0]
+        return check_mxfp8_support()[0]
     return False


124124

src/accelerate/utils/transformer_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,9 @@ def apply_fp8_autowrap(model, fp8_recipe_handler):

     else:
         import transformer_engine.common.recipe as te_recipe
-        import transformer_engine.pytorch as te
+        from transformer_engine.pytorch.fp8 import check_mxfp8_support

-        is_fp8_block_scaling_available, message = te.fp8.check_mxfp8_support()
+        is_fp8_block_scaling_available, message = check_mxfp8_support()

     kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
     if "fp8_format" in kwargs:

0 commit comments

Comments (0)