diff --git a/litgpt/config.py b/litgpt/config.py index cba52e3374..be9e568b13 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -5,7 +5,6 @@ from pathlib import Path from typing import Any, List, Literal, Optional, Type, Union -import torch import yaml from typing_extensions import Self @@ -185,6 +184,8 @@ def norm_class(self) -> Type: from functools import partial + import torch # Torch import is lazy to make config loading faster + if self.norm_class_name == "RMSNorm": from litgpt.model import RMSNorm