
Commit 7f9f667 (parent: feee72b)

Update bloom model support

Signed-off-by: char-1ee <[email protected]>

File tree: 4 files changed, +123 / -100 lines

colossalai/inference/core/engine.py

Lines changed: 14 additions & 2 deletions
@@ -14,6 +14,7 @@
     PreTrainedTokenizerFast,
 )
 from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from transformers.models.bloom.modeling_bloom import BloomForCausalLM
 
 from colossalai.accelerator import get_accelerator
 from colossalai.cluster import ProcessGroupMesh
@@ -39,8 +40,11 @@
 _supported_models = {
     "LlamaForCausalLM": LlamaForCausalLM,
     "BaichuanForCausalLM": AutoModelForCausalLM,
+    "BloomForCausalLM": BloomForCausalLM,
 }
 
+_alibi_models = ["bloom", "baichuan"]
+
 _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
 
 
@@ -79,7 +83,7 @@ def __init__(
         self.tokenizer = tokenizer
         self.tokenizer.pad_token = self.tokenizer.eos_token
 
-        self.request_handler = RequestHandler(self.inference_config, self.model_config)
+        self.request_handler = RequestHandler(self.inference_config, self.model_config, alibi_attn=self.alibi_attn)
         self.k_cache, self.v_cache = self.request_handler.get_kvcache()
         # DISCUSS maybe move this into batch info?
 
@@ -160,6 +164,14 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy
             pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
             tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
 
+        self.alibi_attn = False
+        if self.model_config.model_type in _alibi_models:
+            # Used for bloom, baichuan 13b and baichuan2 13b.
+            self.alibi_attn = True
+            # Hardcode used to distinguish between baichuan 7b and baichuan 13b.(There might be a better way to handle this.)
+            if self.model_config.model_type == "baichuan" and self.model_config.hidden_size == 4096:
+                self.alibi_attn = False
+
         self.model = self._shardformer(
             model,
             model_policy,
@@ -735,4 +747,4 @@ def step(self) -> List[str]:
 
         finished_sequences = self.request_handler.update()
 
-        return finished_sequences
+        return finished_sequences
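
Note: for context, a rough usage sketch of this engine path with a BLOOM checkpoint. Only model_or_path, the tokenizer, and the config object are visible in this diff, so the InferenceConfig fields and generate() signature below are illustrative assumptions, not the commit's API.

from transformers import AutoTokenizer, BloomForCausalLM

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
config = InferenceConfig(max_batch_size=4, max_input_len=256, max_output_len=64)  # assumed fields

engine = InferenceEngine(model, tokenizer, config)
# model_config.model_type == "bloom" is in _alibi_models, so init_model() sets
# self.alibi_attn = True and forwards it to RequestHandler (see the hunks above).
print(engine.generate(prompts=["Hello, my name is"]))  # assumed signature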

colossalai/inference/modeling/models/bloom.py renamed to colossalai/inference/modeling/models/nopadding_bloom.py

Lines changed: 76 additions & 91 deletions
@@ -6,14 +6,15 @@
 )
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.shardformer.shard import ShardConfig
-from colossalai.kernel.triton import flash_decoding_attention_with_alibi
+from colossalai.kernel.triton import flash_decoding_attention, context_attention_unpadded
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.jit.bias_gelu import GeLUFunction
 from colossalai.kernel.jit.bias_dropout_add import bias_dropout_add_fused_inference
 
 
 import torch
 import torch.nn.functional as F
+import torch.nn as nn
 from typing import List, Optional, Tuple
 import math
 
@@ -61,26 +62,9 @@ def _get_alibi_tensor(n_heads: int, mask: torch.Tensor):
     return distance[:, :, None] * slopes[None, None, :]
 
 
-# def _fill_with_neg_inf(t):
-#     return t.float().fill_(float("-inf")).type_as(t)
-
-# # (Register buffer within BloomModel), only use for inference
-# def _get_alibi_tensor(max_pos: int, n_heads: int):
-#     slopes = torch.Tensor(_get_alibi_slopes(n_heads))
-#     alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0) \
-#         .expand(n_heads, -1, -1) \
-#         .view(n_heads, 1, max_pos)
-
-#     alibi_mask = torch.triu(
-#         _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1
-#     )
-#     return alibi_mask.unsqueeze(0) + alibi
-
-
-# TODO
 def bloom_model_forward(
     self: BloomModel,
-    input_tokens_ids: torch.Tensor,
+    input_tokens_ids: torch.Tensor,  # no padding
     output_tensor: torch.Tensor,
     inputmetadata: InputMetaData,
     k_caches: List[torch.Tensor] = None,
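
Note: the kept helper _get_alibi_slopes(n_heads) computes the standard per-head ALiBi slopes (a geometric sequence keyed to the nearest power-of-two head count). A minimal reference sketch of that computation, assuming the conventional ALiBi formulation rather than the exact helper in this file:

import math

import torch


def get_alibi_slopes(num_heads: int) -> torch.Tensor:
    # Slopes form a geometric sequence starting from 2^(-8 / closest_pow2);
    # heads beyond the nearest power of two take odd powers of a second base.
    closest_pow2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-8.0 / closest_pow2)
    slopes = [base ** (i + 1) for i in range(closest_pow2)]
    if closest_pow2 != num_heads:
        extra_base = 2 ** (-4.0 / closest_pow2)
        num_extra = num_heads - closest_pow2
        slopes += [extra_base ** (2 * i + 1) for i in range(num_extra)]
    return torch.tensor(slopes)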
@@ -89,10 +73,10 @@ def bloom_model_forward(
     high_precision: bool = False,
 ) -> torch.Tensor:
 
-    def get_alibi_mask(x: torch.Tensor, past_seq_length: int, is_prompts: bool = False):
-        if is_prompts:
-            is_prompts = False
-            self.register_buffer("future_mask", _get_alibi_tensor())
+    # def get_alibi_mask(x: torch.Tensor, past_seq_length: int, is_prompts: bool = False):
+    #     if is_prompts:
+    #         is_prompts = False
+    #         self.register_buffer("future_mask", _get_alibi_tensor())
 
     is_prompts = inputmetadata.is_prompts
     block_tables = inputmetadata.block_tables
@@ -120,7 +104,7 @@ def get_alibi_mask(x: torch.Tensor, past_seq_length: int, is_prompts: bool = Fal
     # self.max_cache_pos = seq_length_with_past
     # self.register_buffer("future_mask", _get_alibi_tensor(self.n_head, self.max_cache_pos).to(hidden_states), persistent=False)
 
-    alibi = _get_alibi_slopes(self.n_head)
+    # alibi = _get_alibi_slopes(self.num_heads)
     # alibi_mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]
 
     sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
@@ -129,7 +113,6 @@ def get_alibi_mask(x: torch.Tensor, past_seq_length: int, is_prompts: bool = Fal
     for layer_id, layer in enumerate(self.h):
         hidden_states = layer(
             hidden_states,
-            alibi=alibi,
             block_tables=block_tables,
             k_cache=k_caches[layer_id],
             v_cache=v_caches[layer_id],
@@ -138,8 +121,6 @@ def get_alibi_mask(x: torch.Tensor, past_seq_length: int, is_prompts: bool = Fal
             fd_inter_tensor=inputmetadata.fd_inter_tensor,
             kv_seq_len=kv_seq_len,
             output_tensor=output_tensor,
-            use_cuda_kernel=use_cuda_kernel,
-            high_precision=high_precision,
             norm_output=norm_output,
             sm_scale=sm_scale,
             use_cuda_kernel=use_cuda_kernel,
@@ -160,7 +141,7 @@ def bloom_causal_lm_forward(
 ) -> torch.Tensor:
 
     hidden_states = bloom_model_forward(
-        self.model,
+        self.transformer,
         input_tokens_ids=input_tokens_ids,
         output_tensor=output_tensor,
         inputmetadata=inputmetadata,
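
Note: the self.model -> self.transformer change matches Hugging Face's BloomForCausalLM, which exposes its base BloomModel as .transformer (Llama-style models use .model). A quick local check with a tiny throwaway config, purely for illustration:

from transformers import BloomConfig, BloomForCausalLM

# Small config just to inspect attribute names; not part of this commit.
model = BloomForCausalLM(BloomConfig(hidden_size=64, n_layer=2, n_head=4))
print(hasattr(model, "transformer"), hasattr(model, "model"))  # True False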
@@ -173,11 +154,9 @@ def bloom_causal_lm_forward(
     return logits
 
 
-# TODO
 def bloom_block_forward(
     self: BloomBlock,
     hidden_states: torch.Tensor,
-    alibi: torch.Tensor,
     block_tables: torch.Tensor,
     k_cache: torch.Tensor,
     v_cache: torch.Tensor,
@@ -204,17 +183,14 @@ def bloom_block_forward(
     residual = hidden_states
 
     # Self attention
-    attn_output, _ = self.self_attention(
+    attn_output = self.self_attention(
         hidden_states=layernorm_output,
-        residual=residual,
-        alibi=alibi,
-        hidden_states=hidden_states,
         block_tables=block_tables,
         k_cache=k_cache,
         v_cache=v_cache,
         is_prompts=is_prompts,
-        is_verifier=is_verifier,
-        tokens_to_verify=tokens_to_verify,
+        # is_verifier=is_verifier,
+        # tokens_to_verify=tokens_to_verify,
         sequence_lengths=sequence_lengths,
         fd_inter_tensor=fd_inter_tensor,
         kv_seq_len=kv_seq_len,
@@ -233,46 +209,50 @@ def bloom_block_forward(
     else:
         residual = attn_output
 
-    # MLP
-    output = self.mlp(layernorm_output, residual)  # including residuals
+    print(f"[DEBUG] Show attn_output shape: {attn_output.shape}, \
+        show residual shape: {residual.shape} \
+        ")
+
+    # MLP (including residuals)
+    output = self.mlp(layernorm_output, residual)
 
     return output
 
-
-# TODO
-class ColossalInferBloomAttention(BloomAttention):
+
+class NopadBloomAttention(nn.Module):
     def __init__(
         self,
-        config: BloomConfig,
+        hidden_size: int,
+        n_heads: int,
         attn_qproj_w: torch.Tensor = None,
         attn_kproj_w: torch.Tensor = None,
         attn_vproj_w: torch.Tensor = None,
         attn_oproj_w: torch.Tensor = None,
     ):
-        super().__init__(config)
-        self.q_proj_weight = attn_qproj_w
-        self.k_proj_weight = attn_kproj_w
-        self.v_proj_weight = attn_vproj_w
-        self.o_proj_weight = attn_oproj_w
-
-        qkv_weight_list = [self.q_proj_weight, self.k_proj_weight, self.v_proj_weight]
-        self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
+        super().__init__()
 
-        # garbage collection
-        self.q_proj = None
-        self.k_proj = None
-        self.v_proj = None
+        self.hidden_size = hidden_size
+        self.num_heads = n_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.o_proj_w = attn_oproj_w
+
+        qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
+        self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
 
     @staticmethod
-    def from_native_module(module: BloomAttention, *args, **kwargs) -> BloomAttention:
-        config = module.config
-        attn_qproj_w = module.q_proj.weight.transpose(0, 1)
-        attn_kproj_w = module.k_proj.weight.transpose(0, 1)
-        attn_vproj_w = module.v_proj.weight.transpose(0, 1)
-        attn_oproj_w = module.o_proj.weight.transpose(0, 1)
+    def from_native_module(module: nn.Module, *args, **kwargs) -> "NopadBloomAttention":
+        hidden_size = module.hidden_size
+        num_heads = module.num_heads
+        q_proj_w, k_proj_w, v_proj_w = module.query_key_value.weight.view((3, hidden_size, hidden_size))
 
-        attn_layer = ColossalInferBloomAttention(
-            config=config,
+        attn_qproj_w = q_proj_w.transpose(0, 1)
+        attn_kproj_w = k_proj_w.transpose(0, 1)
+        attn_vproj_w = v_proj_w.transpose(0, 1)
+        attn_oproj_w = module.dense.weight.transpose(0, 1)
+
+        attn_layer = NopadBloomAttention(
+            hidden_size=hidden_size,
+            n_heads=num_heads,
             attn_qproj_w=attn_qproj_w,
             attn_kproj_w=attn_kproj_w,
             attn_vproj_w=attn_vproj_w,
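
Note on the fused weight split in from_native_module above: Hugging Face's BloomAttention keeps Q/K/V in a single query_key_value projection whose output is reshaped per head to (num_heads, 3, head_dim). A hedged sketch of splitting the fused weight along that per-head layout (the helper name is illustrative, and whether a flat view((3, hidden_size, hidden_size)) yields the same slices depends on this layout):

import torch


def split_fused_qkv_weight(fused_weight: torch.Tensor, num_heads: int):
    # fused_weight: (3 * hidden_size, hidden_size), as in BloomAttention.query_key_value.
    # Rows are laid out head by head as (num_heads, 3, head_dim), not as three
    # contiguous hidden_size blocks.
    hidden_size = fused_weight.size(1)
    head_dim = hidden_size // num_heads
    w = fused_weight.view(num_heads, 3, head_dim, hidden_size)
    q_w = w[:, 0].reshape(hidden_size, hidden_size)
    k_w = w[:, 1].reshape(hidden_size, hidden_size)
    v_w = w[:, 2].reshape(hidden_size, hidden_size)
    return q_w, k_w, v_w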
@@ -284,7 +264,6 @@ def from_native_module(module: BloomAttention, *args, **kwargs) -> BloomAttentio
     def forward(
         self,
         hidden_states: torch.Tensor,
-        alibi: torch.Tensor,
         block_tables: torch.Tensor,
         k_cache: torch.Tensor,
         v_cache: torch.Tensor,
@@ -297,39 +276,38 @@ def forward(
         use_cuda_kernel: bool = True,
         cu_seqlens: torch.Tensor = None,
         high_precision: bool = False,
-    ):
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 
         token_nums = hidden_states.size(0)
-
         hidden_states = hidden_states.expand(3, -1, -1)
         query_states, key_states, value_states = (
             torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
         )
 
         block_size = k_cache.size(-2)
 
-        if is_prompts:  # Prefilling
-
-            # TODO context stage alibi flash_attn
-            pass
-
-        else:  # Decoding
-
-            # If alibi in this way, then next step is to softmax with matmul_result,
-            # so do I need consider how to utilize the matmul_result
-            matmul_result = alibi.baddbmm(
-                batch1=query_states,
-                batch2=key_states,
-                beta=self.beta,
-                alpha=self.inv_norm_factor,
+        if is_prompts:
+            # TODO(char-1ee) Integrate context stage flash attention with alibi encoding
+            attn_output = context_attention_unpadded(
+                q=query_states,
+                k=key_states,
+                v=value_states,
+                k_cache=k_cache,
+                v_cache=v_cache,
+                context_lengths=sequence_lengths,
+                block_size=block_size,
+                block_tables=block_tables,
+                output=output_tensor,
+                alibi_slopes=fd_inter_tensor.alibi_slopes,
+                max_seq_len=kv_seq_len,
+                sm_scale=sm_scale,
             )
-
-
-            attn_output = flash_decoding_attention_with_alibi(
+        else:
+            attn_output = flash_decoding_attention(
                 q=query_states,
                 k_cache=k_cache,
                 v_cache=v_cache,
-                alibi=alibi,
+                alibi_slopes=fd_inter_tensor.alibi_slopes,
                 kv_seq_len=sequence_lengths,
                 block_tables=block_tables,
                 block_size=block_size,
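
Note: passing per-head alibi_slopes to context_attention_unpadded and flash_decoding_attention lets the kernels rebuild the positional bias on the fly instead of consuming a precomputed alibi tensor. As a reference for the math only, a naive sketch of ALiBi-biased causal attention (not the Triton kernels used here):

import torch


def naive_alibi_attention(q, k, v, slopes, sm_scale):
    # q, k, v: (num_heads, seq_len, head_dim); slopes: (num_heads,) positive values.
    seq_len = q.size(1)
    scores = torch.einsum("hqd,hkd->hqk", q, k) * sm_scale
    # ALiBi bias: slope[h] * (key_pos - query_pos), more negative for distant keys.
    pos = torch.arange(seq_len)
    bias = slopes[:, None, None] * (pos[None, None, :] - pos[None, :, None])
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    scores = (scores + bias).masked_fill(~causal, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v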
@@ -341,23 +319,30 @@
             )
 
         attn_output = attn_output.view(-1, self.hidden_size)
-        attn_output = torch.mm(attn_output, self.o_proj_weight)
-
+        attn_output = torch.mm(attn_output, self.o_proj_w)
         return attn_output
 
 
-class ColossalInferBloomMLP(BloomMLP):
-    def __init__(self, config: BloomConfig):
-        super().__init__(config)
+class NopadBloomMLP(nn.Module):
+    def __init__(self, hidden_size: int = 64, hidden_dropout: float = 0.0):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.hidden_dropout = hidden_dropout
+        self.dense_h_to_4h = nn.Linear(hidden_size, hidden_size * 4)
         self.gelu_impl = GeLUFunction.apply
+        self.dense_4h_to_h = nn.Linear(hidden_size * 4, hidden_size)
+
+        self.dense_h_to_4h = self.dense_h_to_4h.half()
+        self.dense_4h_to_h = self.dense_4h_to_h.half()
 
     @staticmethod
-    def from_native_method(module: BloomMLP, *args, **kwargs) -> BloomMLP:
-        config = module.config
-        mlp_layer = ColossalInferBloomMLP(config=config)
+    def from_native_module(module: nn.Module, *args, **kwargs) -> "NopadBloomMLP":
+        hidden_size = 64  # TODO: hyperparameters
+        mlp_layer = NopadBloomMLP(hidden_size=hidden_size, hidden_dropout=module.hidden_dropout)
         return mlp_layer
 
     def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
+        print(f"[DEBUG] Print shape of hidden_states: {hidden_states.shape}, and dtype is {hidden_states.dtype}")
         hidden_states = self.dense_h_to_4h(hidden_states)
         bias = torch.zero_like(hidden_states)
         hidden_states = self.gelu_impl(hidden_states, bias)
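
Note: for reference, the path NopadBloomMLP mirrors can be written in plain PyTorch as below, with F.gelu and torch.zeros_like standing in for the fused GeLUFunction / bias kernels used in this file (illustrative only):

import torch
import torch.nn.functional as F


def reference_bloom_mlp(hidden_states, residual, dense_h_to_4h, dense_4h_to_h, dropout_p=0.0):
    # h -> 4h, GeLU with a zero bias (matching the zero-bias pattern above),
    # then 4h -> h, dropout, and the residual add.
    h = dense_h_to_4h(hidden_states)
    h = F.gelu(h + torch.zeros_like(h))
    h = dense_4h_to_h(h)
    return F.dropout(h, p=dropout_p, training=False) + residual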
Lines changed: 3 additions & 1 deletion
@@ -1,17 +1,19 @@
 from .glide_llama import GlideLlamaModelPolicy
 from .nopadding_baichuan import NoPaddingBaichuanModelInferPolicy
 from .nopadding_llama import NoPaddingLlamaModelInferPolicy
+from .nopadding_bloom import NoPaddingBloomModelInferPolicy
 
 model_policy_map = {
     "nopadding_llama": NoPaddingLlamaModelInferPolicy,
     "nopadding_baichuan": NoPaddingBaichuanModelInferPolicy,
+    "nopadding_bloom": NoPaddingBloomModelInferPolicy,
     "glide_llama": GlideLlamaModelPolicy,
 }
 
 __all__ = [
     "NoPaddingLlamaModelInferPolicy",
     "NoPaddingBaichuanModelInferPolicy",
     "GlideLlamaModelPolicy",
-    "BloomModelInferPolicy",
+    "NoPaddingBloomModelInferPolicy",
     "model_polic_map",
 ]
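
Note: a small lookup sketch for the new map entry. The package path below is an assumption (the extracted header above does not show the path of this __init__.py):

from colossalai.inference.modeling.policy import model_policy_map

policy_cls = model_policy_map["nopadding_bloom"]
print(policy_cls.__name__)  # NoPaddingBloomModelInferPolicy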
