From 76a28fc538d17a80e78656bf80da9048f070a95d Mon Sep 17 00:00:00 2001 From: Talmeez Fuaad <87268503+itstalmeez@users.noreply.github.com> Date: Sun, 26 Nov 2023 23:18:20 +0500 Subject: [PATCH] Update sample.py --- grade_school_math/sample.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/grade_school_math/sample.py b/grade_school_math/sample.py index 1c3f14d..2350b2c 100644 --- a/grade_school_math/sample.py +++ b/grade_school_math/sample.py @@ -2,21 +2,28 @@ from dataset import get_examples, GSMDataset from calculator import sample from transformers import GPT2Tokenizer, GPT2LMHeadModel - +import logging def main(): - device = th.device("cuda") - tokenizer = GPT2Tokenizer.from_pretrained("gpt2") - model = GPT2LMHeadModel.from_pretrained("model_ckpts") - model.to(device) - print("Model Loaded") + try: + # Set up logging + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + device = th.device("cuda") + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + model = GPT2LMHeadModel.from_pretrained("model_ckpts") + model.to(device) + logger.info("Model Loaded") - test_examples = get_examples("test") - qn = test_examples[1]["question"] - sample_len = 100 - print(qn.strip()) - print(sample(model, qn, tokenizer, device, sample_len)) + test_examples = get_examples("test") + qn = test_examples[1]["question"] + sample_len = 100 + logger.info(qn.strip()) + logger.info(sample(model, qn, tokenizer, device, sample_len)) + except Exception as e: + logging.exception("An error occurred: %s", str(e)) if __name__ == "__main__": main()