Skip to content

Commit 332d98a

Browse files
committed
.
1 parent 935dd70 commit 332d98a

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

auto_fp8/modeling.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ def from_pretrained(
2828
quantize_config: BaseQuantizeConfig,
2929
**model_init_kwargs,
3030
):
31-
"""load un-quantized pretrained model to cpu"""
31+
"""Load the un-quantized pretrained model"""
3232

33-
if not torch.cuda.is_available():
34-
raise EnvironmentError(
35-
"Load pretrained model to do quantization requires CUDA available."
36-
)
33+
# if not torch.cuda.is_available():
34+
# raise EnvironmentError(
35+
# "Load pretrained model to do quantization requires CUDA available."
36+
# )
3737

3838
def skip(*args, **kwargs):
3939
pass
@@ -88,9 +88,7 @@ def skip(*args, **kwargs):
8888
model.seqlen = model_config[key]
8989
break
9090
else:
91-
print(
92-
"can't get model's sequence length from model config, will set to 2048."
93-
)
91+
print("Can't get model's sequence length, setting to 2048.")
9492
model.seqlen = 2048
9593
model.eval()
9694

0 commit comments

Comments (0)