Skip to content

Commit a03dd40

Browse files
[fix] Remove TE v1 from low precision options (#2785)
1 parent ef94b5b commit a03dd40

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

thunder/benchmarks/benchmark_litgpt.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,20 +79,18 @@ class LowPrecisionHandler:
7979
enabled: bool = False
8080

8181
use_thunder_te: bool = True
82-
use_legacy_thunder_te: bool = False
8382

8483
@property
8584
def use_te_autocast(self) -> bool:
86-
return self.enabled and not self.use_legacy_thunder_te
85+
return self.enabled
8786

8887
def check_and_add_compile_options(self, compile_options: str) -> str:
8988
if not self.enabled and "_transformerengine" in compile_options:
9089
raise ValueError("Low precision mode not specified but found transfomerengine in the compile options!")
9190

9291
self.use_thunder_te = "thunder" in compile_options
93-
self.use_legacy_thunder_te = "transformerengine_v1" in compile_options
9492

95-
if self.enabled and self.use_thunder_te and not self.use_legacy_thunder_te:
93+
if self.enabled and self.use_thunder_te:
9694
compile_options += "_transformerengine"
9795
return compile_options
9896

@@ -120,15 +118,13 @@ def __post_init__(self) -> None:
120118
def executor_str(self) -> str:
121119
if not self.enabled:
122120
return "low precision is not enabled"
123-
elif self.use_legacy_thunder_te:
124-
return "Thunder TE executor v1"
125121
elif self.use_thunder_te:
126122
return "Thunder TE executor"
127123
else:
128124
return "TransformerEngine without Thunder"
129125

130126
def maybe_apply_te_autocast(self):
131-
if self.enabled and not self.use_legacy_thunder_te:
127+
if self.enabled:
132128
return te.fp8_autocast(fp8_recipe=self.fp8_recipe)
133129
else:
134130
return nullcontext()
@@ -160,7 +156,7 @@ def update_if_not_divisible(attr_name, divisor):
160156
print("No updates were necessary.")
161157

162158
def swap_linear_layers_for_te(self, model: torch.nn.Module, device: Any) -> None:
163-
if not self.enabled or self.use_thunder_te:
159+
if not self.enabled:
164160
return
165161

166162
swap_layernorm = self.mode == "fp8-default-te-wo_layernorm"

0 commit comments

Comments
 (0)