Skip to content
This repository was archived by the owner on Nov 16, 2021. It is now read-only.

Commit bc98220

Browse files
committed
add warmup
Signed-off-by: Bedapudi Praneeth <[email protected]>
1 parent dc66268 commit bc98220

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

fastpunct/fastpunct.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import torch
3+
import logging
34
import pydload
45
from transformers import T5Tokenizer, T5ForConditionalGeneration
56

@@ -23,7 +24,7 @@ def __init__(self, language='english', checkpoint_local_path=None):
2324
model_name = language.lower()
2425

2526
if model_name not in MODEL_URLS:
26-
print(f"model_name should be one of {list(MODEL_URLS.keys())}")
27+
logging.warn(f"model_name should be one of {list(MODEL_URLS.keys())}")
2728
return None
2829

2930
home = os.path.expanduser("~")
@@ -39,7 +40,7 @@ def __init__(self, language='english', checkpoint_local_path=None):
3940
file_path = os.path.join(lang_path, file_name)
4041
if os.path.exists(file_path):
4142
continue
42-
print(f"Downloading {file_name}")
43+
logging.info(f"Downloading {file_name}")
4344
pydload.dload(url=url, save_to_path=file_path, max_time=None)
4445

4546
self.tokenizer = T5Tokenizer.from_pretrained(lang_path)
@@ -48,8 +49,11 @@ def __init__(self, language='english', checkpoint_local_path=None):
4849
)
4950

5051
if torch.cuda.is_available():
51-
print(f"Using GPU")
52+
logging.info(f"Using GPU")
5253
self.model = self.model.cuda()
54+
55+
logging.info("Warming up")
56+
self.punct(["i am batman"])
5357

5458
def punct(
5559
self, sentences, beam_size=1, max_len=None, correct=False

0 commit comments

Comments
 (0)