Skip to content
29 changes: 24 additions & 5 deletions litgpt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,30 @@
import logging
import re

from litgpt.api import LLM
from litgpt.config import Config
from litgpt.model import GPT # needs to be imported before config
from litgpt.prompts import PromptStyle
from litgpt.tokenizer import Tokenizer

def __getattr__(name):
if name == "LLM":
from litgpt.api import LLM

return LLM
elif name == "Config":
from litgpt.config import Config

return Config
elif name == "GPT":
from litgpt.model import GPT

return GPT
elif name == "PromptStyle":
from litgpt.prompts import PromptStyle

return PromptStyle
elif name == "Tokenizer":
from litgpt.tokenizer import Tokenizer

return Tokenizer
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


# Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632
pattern = re.compile(".*Profiler function .* will be ignored")
Expand Down
3 changes: 2 additions & 1 deletion litgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from pathlib import Path
from typing import Any, List, Literal, Optional, Type, Union

import torch
import yaml
from typing_extensions import Self

Expand Down Expand Up @@ -185,6 +184,8 @@ def norm_class(self) -> Type:

from functools import partial

import torch

if self.norm_class_name == "RMSNorm":
from litgpt.model import RMSNorm

Expand Down
Loading