
Commit 74c4792

[Fix] Llama3 Load/Omit CheckpointIO Temporarily (#5717)
* Fix Llama3 Load error
* Omit Checkpoint IO Temporarily
1 parent 5bbab15 commit 74c4792

3 files changed, +72 -69 lines changed

colossalai/inference/core/engine.py
Lines changed: 12 additions & 14 deletions

@@ -24,7 +24,7 @@
 from colossalai.inference.sampler import search_tokens
 from colossalai.inference.spec import Drafter, GlideInput
 from colossalai.inference.struct import Sequence
-from colossalai.inference.utils import get_model_size, has_index_file
+from colossalai.inference.utils import get_model_size
 from colossalai.interface import ModelWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -113,18 +113,15 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy
            model_policy (Policy): the policy to replace the model
        """

-        casuallm = None
        if isinstance(model_or_path, str):
            try:
                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
                arch = getattr(hf_config, "architectures")[0]
                if arch in _supported_models.keys():
-                    casuallm = _supported_models[arch](hf_config)
-                    if isinstance(casuallm, AutoModelForCausalLM):
-                        # NOTE(caidi) It's necessary to add half() here, otherwise baichuan13B will overflow the memory.
-                        model = AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half()
-                    else:
-                        model = _supported_models[arch](hf_config)
+                    # NOTE(lry89757) Currently we load the model using transformers-api,
+                    # but we will use lazy tensor and checkpoint io to accelerate
+                    # the model load process in the future.
+                    model = _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
                else:
                    raise ValueError(f"Model {arch} is not supported.")
@@ -175,13 +172,14 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy
                f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
            )

-        if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
-            from colossalai.inference.core.plugin import InferCheckpoint_io
+        # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
+        # if isinstance(model_or_path, str) and not isinstance(casuallm, AutoModelForCausalLM):
+        #     from colossalai.inference.core.plugin import InferCheckpoint_io

-            cpt_io = InferCheckpoint_io()
-            if_has_index_file, model_index_file = has_index_file(model_or_path)
-            assert if_has_index_file, "the model path is invalid"
-            cpt_io.load_model(self.model, model_index_file)
+        #     cpt_io = InferCheckpoint_io()
+        #     if_has_index_file, model_index_file = has_index_file(model_or_path)
+        #     assert if_has_index_file, "the model path is invalid"
+        #     cpt_io.load_model(self.model, model_index_file)

        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        peak_memory = init_gpu_memory - free_gpu_memory
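
For reference, the load path that engine.py settles on after this change reduces to resolving the architecture name from the Hugging Face config and calling from_pretrained on the matching class, with the checkpoint-io branch disabled until lazy tensor support lands. A minimal standalone sketch, assuming a _supported_models dict keyed by architecture name as in the file above (the illustrative subset below is not the real registry):

```python
# Minimal sketch of the post-commit loading path (assumed: a _supported_models
# mapping from architecture strings to transformers classes, illustrative only).
from transformers import AutoConfig, LlamaForCausalLM

_supported_models = {"LlamaForCausalLM": LlamaForCausalLM}  # illustrative subset

def load_for_inference(model_or_path: str):
    # Resolve the architecture from the checkpoint's config.
    hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
    arch = getattr(hf_config, "architectures")[0]
    if arch not in _supported_models:
        raise ValueError(f"Model {arch} is not supported.")
    # Weights now come straight from the transformers API; the lazy-tensor /
    # InferCheckpoint_io fast path is commented out until it is reintroduced.
    return _supported_models[arch].from_pretrained(model_or_path, trust_remote_code=True)
```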

colossalai/inference/executor/rpc_worker.py
Lines changed: 18 additions & 14 deletions

@@ -1,4 +1,3 @@
-import os
 from typing import List, Tuple, Union

 import rpyc
@@ -19,7 +18,7 @@
    model_policy_map,
)
 from colossalai.inference.sampler import search_tokens
-from colossalai.inference.utils import get_model_size, has_index_file
+from colossalai.inference.utils import get_model_size
 from colossalai.interface import ModelWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -178,15 +177,19 @@ def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy
        """

        if isinstance(model_or_path, str):
-            is_local = os.path.isdir(model_or_path)
+            # is_local = os.path.isdir(model_or_path)
            try:
                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
                arch = getattr(hf_config, "architectures")[0]
-                if is_local:
-                    model = _SUPPORTED_MODELS[arch](hf_config)
-                else:
-                    # load the real checkpoint
-                    model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
+                # NOTE(lry89757) Currently we load the model using transformers-api,
+                # but we will use lazy tensor and checkpoint io to accelerate
+                # the model load process in the future.
+                model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
+                # if is_local:
+                #     model = _SUPPORTED_MODELS[arch](hf_config)
+                # else:
+                #     # load the real checkpoint
+                #     model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True)
            except Exception as e:
                logger.error(
                    f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
@@ -235,13 +238,14 @@ def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy
                f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
            )

-        if isinstance(model_or_path, str) and is_local:
-            from colossalai.inference.core.plugin import InferCheckpoint_io
+        # NOTE(lry89757) Deprecated currently, will reused when introduce lazy tensor
+        # if isinstance(model_or_path, str) and is_local:
+        #     from colossalai.inference.core.plugin import InferCheckpoint_io

-            cpt_io = InferCheckpoint_io()
-            if_has_index_file, model_index_file = has_index_file(model_or_path)
-            assert if_has_index_file, "the model path is invalid"
-            cpt_io.load_model(self.model, model_index_file)
+        #     cpt_io = InferCheckpoint_io()
+        #     if_has_index_file, model_index_file = has_index_file(model_or_path)
+        #     assert if_has_index_file, "the model path is invalid"
+        #     cpt_io.load_model(self.model, model_index_file)

        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        peak_memory = init_gpu_memory - free_gpu_memory
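
The disabled checkpoint-io branch above relied on has_index_file(model_or_path) returning an (if_has_index_file, model_index_file) pair. The real utility in colossalai.inference.utils is not shown in this diff; the sketch below is only an assumed illustration of that contract, based on the fact that sharded Hugging Face checkpoints ship a *.index.json file alongside the weight shards:

```python
# Hedged sketch of the kind of check the now-disabled checkpoint-io path relied on.
# The actual colossalai.inference.utils.has_index_file may differ; this only
# illustrates the (found, index_path) contract used in the commented-out code.
import os
from typing import Optional, Tuple

def has_index_file_sketch(checkpoint_dir: str) -> Tuple[bool, Optional[str]]:
    if not os.path.isdir(checkpoint_dir):
        return False, None
    for name in os.listdir(checkpoint_dir):
        # Sharded HF checkpoints include a *.index.json mapping weights to shard files.
        if name.endswith(".index.json"):
            return True, os.path.join(checkpoint_dir, name)
    return False, None
```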

colossalai/inference/modeling/models/nopadding_llama.py
Lines changed: 42 additions & 41 deletions

@@ -646,48 +646,49 @@ def forward(
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
-        # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
-        for hook in self._load_state_dict_pre_hooks.values():
-            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
-        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
-        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
-        local_state = {k: v for k, v in local_name_params if v is not None}
-
-        key = "qkv_weight"
-        k1 = "q_proj.weight"
-        k2 = "k_proj.weight"
-        k3 = "v_proj.weight"
-        q_w = state_dict[prefix + k1]
-        k_w = state_dict[prefix + k2]
-        v_w = state_dict[prefix + k3]
-
-        device_mesh = self.helper_layout.device_mesh
-        sharding_spec = self.helper_layout.sharding_spec
-        q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
-        k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
-        v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
-
-        qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
-
-        input_param = nn.Parameter(
-            qkv_w
-        )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
-
-        param = local_state[key]
-
-        try:
-            with torch.no_grad():
-                param.copy_(input_param)
-        except Exception as ex:
-            error_msgs.append(
-                'While copying the parameter named "{}", '
-                "whose dimensions in the model are {} and "
-                "whose dimensions in the checkpoint are {}, "
-                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
-            )
+        if self.num_heads == self.num_key_value_heads:
+            # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
+            for hook in self._load_state_dict_pre_hooks.values():
+                hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+            persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+            local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+            local_state = {k: v for k, v in local_name_params if v is not None}
+
+            key = "qkv_weight"
+            k1 = "q_proj.weight"
+            k2 = "k_proj.weight"
+            k3 = "v_proj.weight"
+            q_w = state_dict[prefix + k1]
+            k_w = state_dict[prefix + k2]
+            v_w = state_dict[prefix + k3]
+
+            device_mesh = self.helper_layout.device_mesh
+            sharding_spec = self.helper_layout.sharding_spec
+            q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
+            k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
+            v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
+
+            qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
+
+            input_param = nn.Parameter(
+                qkv_w
+            )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
+
+            param = local_state[key]
+
+            try:
+                with torch.no_grad():
+                    param.copy_(input_param)
+            except Exception as ex:
+                error_msgs.append(
+                    'While copying the parameter named "{}", '
+                    "whose dimensions in the model are {} and "
+                    "whose dimensions in the checkpoint are {}, "
+                    "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                )

-        strict = False  # to avoid unexpected_keys
+            strict = False  # to avoid unexpected_keys
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )
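
The new if self.num_heads == self.num_key_value_heads: guard is what addresses the Llama 3 load error: Llama 3 uses grouped-query attention, so k_proj and v_proj are narrower than q_proj and the torch.stack([q_w.T, k_w.T, v_w.T], dim=0) fusion cannot apply; in that case the method falls through to the default _load_from_state_dict. A standalone sketch of the shape behavior, with assumed Llama-3-8B-like sizes (32 query heads, 8 KV heads, head dim 128, hidden 4096):

```python
# Illustration (assumed Llama-3-8B-style shapes) of why the fused-QKV copy is
# guarded by num_heads == num_key_value_heads: under grouped-query attention the
# K/V projections are smaller than Q, so stacking the transposed weights fails.
import torch

hidden = 4096
q_w = torch.empty(32 * 128, hidden)  # 32 query heads
k_w = torch.empty(8 * 128, hidden)   # only 8 key/value heads (GQA)
v_w = torch.empty(8 * 128, hidden)

try:
    torch.stack([q_w.T, k_w.T, v_w.T], dim=0)  # mirrors the hack above
except RuntimeError as e:
    print(f"stack fails under GQA: {e}")

# With equal head counts (MHA-style checkpoints) the fused layout works:
k_w = v_w = torch.empty(32 * 128, hidden)
qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
print(qkv_w.shape)  # torch.Size([3, 4096, 4096])
```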
