Skip to content

Commit 2037199

Browse files
committed
fixing circular import
Summary: The import cycle was generate.py -> tp.py -> quantize.py -> eval.py -> generate.py. I removed the link between generate.py and tp.py by deferring that import until runtime, when quantize.py will be fully initialized. Note that the try/except is still needed for the lm_eval stuff in case it is not installed. Also removed issues with initializing tasks multiple times when the new line in generate.py is hit. Test Plan: (with lm_eval 0.3/0.4/not installed) python quantize.py --mode int8 Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: b3efa96 Pull Request resolved: #97
1 parent 6df1d1f commit 2037199

File tree

3 files changed

+13
-3
lines changed

3 files changed

+13
-3
lines changed

eval.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from lm_eval.models.huggingface import HFLM as eval_wrapper
3636
from lm_eval.tasks import get_task_dict
3737
from lm_eval.evaluator import evaluate
38-
lm_eval.tasks.initialize_tasks()
3938
except: #lm_eval version 0.3
4039
from lm_eval import base
4140
from lm_eval import tasks
@@ -179,6 +178,11 @@ def eval(
179178
max_seq_length,
180179
)
181180

181+
try:
182+
lm_eval.tasks.initialize_tasks()
183+
except:
184+
pass
185+
182186
if 'hendrycks_test' in tasks:
183187
tasks.remove('hendrycks_test')
184188
tasks += [x for x in lm_eval.tasks.hendrycks_test.create_all_tasks().keys()]

generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def device_sync(device):
3434
from sentencepiece import SentencePieceProcessor
3535

3636
from model import Transformer
37-
from tp import maybe_init_dist
3837

3938

4039
def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
@@ -268,6 +267,7 @@ def main(
268267
assert tokenizer_path.is_file(), tokenizer_path
269268

270269
global print
270+
from tp import maybe_init_dist
271271
rank = maybe_init_dist()
272272
use_tp = rank is not None
273273
if use_tp:

quantize.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
try:
1515
from GPTQ import GenericGPTQRunner, InputRecorder
16-
from eval import get_task_dict, evaluate
16+
from eval import get_task_dict, evaluate, lm_eval
1717
except:
1818
pass
1919

@@ -249,8 +249,14 @@ def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibrati
249249
calibration_seq_length,
250250
pad_calibration_inputs,
251251
)
252+
253+
try:
254+
lm_eval.tasks.initialize_tasks()
255+
except:
256+
pass
252257
task_dict = get_task_dict(calibration_tasks)
253258
print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)
259+
254260
evaluate(
255261
input_recorder,
256262
task_dict,

0 commit comments

Comments
 (0)