Skip to content

Commit d147ad0

Browse files
authored
[#2730][fix] Fix circular import bug in medusa/weight.py (#9866)
Signed-off-by: Kanghwan Jang <[email protected]>
1 parent 454e7e5 commit d147ad0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tensorrt_llm/models/medusa/weight.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
1212
from transformers.pytorch_utils import Conv1D
1313

14-
from tensorrt_llm import logger
1514
from tensorrt_llm._utils import str_dtype_to_torch
15+
from tensorrt_llm.logger import logger
1616
from tensorrt_llm.mapping import Mapping
1717
from tensorrt_llm.models.convert_utils import (dup_kv_weight, generate_int8,
1818
smooth_gemm,
@@ -51,7 +51,7 @@ def load_medusa_hf(medusa_path: str,
5151
use_weight_only=False,
5252
plugin_weight_only_quant_type=None,
5353
is_modelopt_ckpt=False):
54-
# logger.info("Loading Medusa heads' weights ...")
54+
logger.info("Loading Medusa heads' weights ...")
5555

5656
if is_modelopt_ckpt:
5757
from safetensors.torch import load_file

0 commit comments

Comments
 (0)