Skip to content

Commit 1fb7fd0

Browse files
author
rtp-llm
committed
feat - support sparse & RoBERTa embedding, support calculating similarity
1 parent 65278e7 commit 1fb7fd0

19 files changed

+461
-177
lines changed

maga_transformer/async_decoder_engine/embedding/embedding_decoder_engine.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ async def _generate_loop(self, streams: List[EmbeddingStream]) -> List[Embedding
4040
if all(finished):
4141
break
4242
await asyncio.sleep(0.001)
43-
4443
return [stream.output for stream in streams]
4544

4645
@torch.inference_mode()
@@ -57,7 +56,9 @@ def step(self):
5756
self.batch_input_.tp_sync()
5857
embedding_outputs = self.executor_.process(self.batch_input_)
5958
if g_parallel_info.tp_rank == 0:
60-
for idx, stream in enumerate(streams):
59+
# do synchronize before update result
60+
torch.cuda.synchronize()
61+
for idx, stream in enumerate(streams):
6162
stream.update(embedding_outputs[idx])
6263
self.report_metric(len(streams), t.cost_ms())
6364

maga_transformer/async_decoder_engine/embedding/embedding_model_executor.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters
77
from maga_transformer.ops.gpt_ops.gpt_op import GptOp
88
from maga_transformer.async_decoder_engine.embedding.embedding_stream import EmbeddingBatchedInput, EmbeddingOutput
9-
from maga_transformer.async_decoder_engine.embedding.post_process.post_process_factory import PostProcessFactory
9+
from maga_transformer.async_decoder_engine.embedding.post_process.post_process_module import PostProcessModule
1010

1111
class EmbeddingModelExecutor(object):
1212
def __init__(self, model: BaseModel, config: GptInitModelParameters):
@@ -15,10 +15,10 @@ def __init__(self, model: BaseModel, config: GptInitModelParameters):
1515
self.gpt_op_ = GptOp(self.config_, False)
1616
self.gpt_op_.set_weight(self.model_.weight)
1717

18-
self.post_process_module_ = PostProcessFactory.create_post_process_module(self.config_, self.model_.dtype)
18+
self.post_process_module_ = PostProcessModule(self.config_, self.model_.dtype, self.model_.tokenizer)
1919

20-
def _pre_process(self, batch_input: EmbeddingBatchedInput):
21-
combo_tokens_tensor = to_cuda(torch.IntTensor(batch_input.combo_tokens))
20+
def _pre_process(self, batch_input: EmbeddingBatchedInput):
21+
combo_tokens_tensor = to_cuda(torch.IntTensor(batch_input.combo_tokens))
2222
position_ids_tensor = to_cuda(self.model_.create_context_position_ids(batch_input.context_lengths_list))
2323
input_embeds = self.model_.async_input_word_embedding(combo_tokens_tensor, [])
2424
if self.model_.position_encoding is not None:
@@ -29,7 +29,7 @@ def _pre_process(self, batch_input: EmbeddingBatchedInput):
2929

3030
if self.model_.pre_decoder_layernorm is not None:
3131
input_embeds = self.model_.pre_decoder_layernorm(input_embeds)
32-
32+
3333
attention_mask = self.model_.create_context_decoder_mask(batch_input.context_lengths_list)
3434
return input_embeds, attention_mask, position_ids_tensor
3535

@@ -50,6 +50,6 @@ def process(self, batch_input: EmbeddingBatchedInput) -> List[EmbeddingOutput]:
5050
prefix_lengths=torch.IntTensor([0] * batch_input.batch_size),
5151
count_length=torch.BoolTensor([True]),
5252
max_prefix_length=torch.IntTensor([0]),
53-
lora_ids=torch.IntTensor([-1] * batch_input.batch_size))
54-
output = self.post_process_module_.process(batch_input, hidden_states, attention_mask)
53+
lora_ids=torch.IntTensor([-1] * batch_input.batch_size))
54+
output = self.post_process_module_.process(batch_input, hidden_states, attention_mask, batch_input.embedding_config)
5555
return output

maga_transformer/async_decoder_engine/embedding/embedding_scheduler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ def schedule(self) -> List[EmbeddingStream]:
3030
for stream in copy.copy(self.waiting_streams_):
3131
if total_len + stream.input.input_length > self.config_.max_context_batch_size * self.config_.max_seq_len:
3232
break
33+
# make sure embedding config is the same
34+
if len(new_streams) > 0 and stream.input.embedding_config != new_streams[0].input.embedding_config:
35+
break
3336
new_streams.append(stream)
3437
total_len += stream.input.input_length
3538

maga_transformer/async_decoder_engine/embedding/embedding_stream.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
import torch
2-
from typing import Any, List, Optional
2+
from typing import Any, List, Dict, Optional
33
from maga_transformer.utils.util import to_cuda, to_cpu
44

55
from maga_transformer.distribute.worker_info import g_parallel_info
6-
from maga_transformer.config.generate_config import GenerateConfig
6+
from maga_transformer.embedding.embedding_config import EmbeddingGenerateConfig
77
from maga_transformer.config.base_model_config import PyDanticModelBase
88

99
class EmbeddingInput(PyDanticModelBase):
1010
token_ids: List[int]
1111
token_type_ids: List[int]
1212
input_length: int
13-
generate_config: GenerateConfig
13+
embedding_config: EmbeddingGenerateConfig
1414

1515
class EmbeddingOutput(PyDanticModelBase):
1616
sentence_embedding: Optional[torch.Tensor] = None
17-
sparse_embedding: Optional[torch.Tensor] = None
17+
sparse_embedding: Optional[Dict[str, float]] = None
1818
colbert_embedding: Optional[torch.Tensor] = None
1919

2020
class EmbeddingStream(PyDanticModelBase):
@@ -26,10 +26,9 @@ class EmbeddingStream(PyDanticModelBase):
2626
def set_error(self, error: str):
2727
self.error_info = error
2828

29-
def update(self,
30-
embedding_output: EmbeddingOutput):
31-
self.finished = True
29+
def update(self, embedding_output: EmbeddingOutput):
3230
self.output = embedding_output
31+
self.finished = True
3332

3433
class EmbeddingBatchedInput(object):
3534
def __init__(self, nccl_op: Any) -> None:
@@ -41,6 +40,8 @@ def clear(self):
4140
self.context_lengths_list: List[int] = []
4241
self.combo_tokens: List[int] = []
4342
self.combo_token_type_ids: List[int] = []
43+
# no need to broadcast embedding config since only tp=0 will use it
44+
self.embedding_config = EmbeddingGenerateConfig()
4445

4546
def generate_model_input(self, streams: List[EmbeddingStream]):
4647
self.clear()
@@ -51,6 +52,7 @@ def generate_model_input(self, streams: List[EmbeddingStream]):
5152
self.combo_tokens.extend(stream.input.token_ids)
5253
self.combo_token_type_ids.extend(stream.input.token_type_ids)
5354
self.batch_size = len(self.context_lengths_list)
55+
self.embedding_config = streams[0].input.embedding_config
5456
self.token_num = len(self.combo_tokens)
5557

5658
def tp_sync(self):
@@ -64,13 +66,13 @@ def tp_sync(self):
6466
torch.cuda.current_stream().synchronize()
6567
shape_hints = shape_hints.cpu().numpy()
6668
assert shape_hints[0] == check_num and shape_hints[-1] == check_num2, 'check sum error'
67-
69+
6870
if g_parallel_info.tp_rank == 0:
6971
context_length_tensor = to_cuda(torch.IntTensor(self.context_lengths_list))
7072
combo_tokens_tensor = to_cuda(torch.IntTensor(self.combo_tokens))
7173
combo_token_type_ids_tensor = to_cuda(torch.IntTensor(self.combo_token_type_ids))
7274
else:
73-
self.batch_size = shape_hints[1]
75+
self.batch_size = shape_hints[1]
7476
self.token_num = shape_hints[2]
7577
context_length_tensor = torch.zeros([self.batch_size], dtype=torch.int32, device="cuda:0")
7678
combo_tokens_tensor = torch.zeros([self.token_num], dtype=torch.int32, device="cuda:0")
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import os
2+
from numpy.typing import NDArray
3+
import numpy as np
4+
import torch
5+
from typing import List, Dict, Union, Optional
6+
7+
from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters
8+
9+
class ColBertEmbeddingModule(object):
    """ColBERT-style per-token embedding head.

    Applies a learned linear projection to every token's hidden state
    (skipping the leading cls position) and returns one variable-length
    tensor of token embeddings per sequence in the batch.
    """

    def __init__(self, hidden_size: int, state_dict: Dict[str, torch.Tensor], dtype: Union[str, torch.dtype]):
        linear = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size)
        linear.load_state_dict(state_dict)
        # Cast to the model dtype, then move onto the GPU.
        self.colbert_linear = linear.to(dtype).cuda()

    def _process_colbert_vecs(self, colbert_vecs: torch.Tensor, tokens_num: int):
        # Drop the vectors of padding tokens; the cls embedding is unused,
        # so only the first tokens_num - 1 rows are kept.
        return colbert_vecs[:tokens_num - 1]

    def __call__(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, input_length: List[int], do_normalize: bool=True) -> List[torch.Tensor]:
        """Project hidden states and return a per-sequence list of token embeddings.

        hidden_states / attention_mask are batched with the cls position at
        index 0, which is sliced off before projecting.
        """
        projected = self.colbert_linear(hidden_states[:, 1:])
        # Zero out padding positions (the mask is sliced the same way as the states).
        projected = projected * attention_mask[:, 1:][:, :, None].float()
        if do_normalize:
            projected = torch.nn.functional.normalize(projected, dim=-1)
        on_cpu = projected.cpu()
        return [self._process_colbert_vecs(vecs, length) for vecs, length in zip(on_cpu, input_length)]
26+
27+
def init_colbert_embedding_module(config: GptInitModelParameters, dtype: Union[str, torch.dtype]) -> Optional[ColBertEmbeddingModule]:
    """Build a ColBertEmbeddingModule when the checkpoint ships colbert weights.

    Returns None when `colbert_linear.pt` is absent from the checkpoint
    directory, i.e. the model does not support colbert embeddings.
    """
    weight_path = os.path.join(config.ckpt_path, 'colbert_linear.pt')
    if not os.path.exists(weight_path):
        return None
    colbert_state_dict = torch.load(weight_path, map_location='cpu')
    return ColBertEmbeddingModule(config.hidden_size, colbert_state_dict, dtype)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import os
2+
import json
3+
from collections import OrderedDict
4+
import torch
5+
import torch.nn as nn
6+
from typing import List, Dict, Union
7+
8+
from sentence_transformers.util import import_from_string
9+
from sentence_transformers.models import Transformer, Normalize
10+
11+
from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters
12+
13+
class DenseEmbeddingModule(object):
    """Interface for dense sentence-embedding heads.

    Subclasses turn per-token hidden states into a single embedding
    per sequence.
    """

    def __call__(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, input_length: List[int], do_normalize: bool) -> torch.Tensor:
        # Concrete subclasses must supply the pooling logic.
        raise NotImplementedError()
16+
17+
def init_dense_embedding_module(config: GptInitModelParameters, dtype: Union[str, torch.dtype]) -> DenseEmbeddingModule:
    """Select the dense head for this checkpoint.

    Uses the sentence-transformers module stack when the checkpoint contains
    a `modules.json`; otherwise falls back to plain first/last-token pooling.
    """
    if os.path.exists(os.path.join(config.ckpt_path, 'modules.json')):
        return SentenceTransformerModule(config, dtype)
    return NormalModule(config.is_causal)
23+
24+
class NormalModule(DenseEmbeddingModule):
    """Dense pooling without extra learned weights.

    Causal (decoder-style) models are pooled at the last valid token of each
    sequence; otherwise the first token (cls position) is used.
    """

    def __init__(self, is_casual: bool):
        # NOTE(review): "casual" is a pre-existing misspelling of "causal";
        # the name is kept so existing keyword callers do not break.
        self.is_casual = is_casual

    def __call__(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, input_lengths: List[int], do_normalize: bool) -> torch.Tensor:
        """Return one pooled embedding per sequence, optionally L2-normalized."""
        if self.is_casual:
            # Last valid token of each sequence.
            ts = torch.stack([hidden_states[idx][length - 1] for idx, length in enumerate(input_lengths)])
        else:
            # First token (cls position) of each sequence.
            ts = torch.stack([hidden_states[idx][0] for idx in range(len(input_lengths))])

        if do_normalize:
            ts = torch.nn.functional.normalize(ts, dim=1)
        return ts
38+
39+
class SentenceTransformerModule(DenseEmbeddingModule):
    """Dense head built from a sentence-transformers module stack.

    Reconstructs the pooling/normalize pipeline described by the checkpoint's
    `modules.json`, while deliberately skipping the Transformer backbone
    module — presumably because the backbone forward pass is executed by this
    engine rather than by sentence-transformers (TODO confirm).
    """

    def __init__(self, config: GptInitModelParameters, dtype: Union[str, torch.dtype]):
        modules_config_path = os.path.join(config.ckpt_path, 'modules.json')
        assert os.path.exists(modules_config_path), "not found modules.json from sentence_transformer"
        with open(modules_config_path) as fIn:
            modules_config = json.load(fIn)
        # Ordered: sentence-transformers applies modules strictly in sequence.
        modules: OrderedDict[str, nn.Module] = OrderedDict()
        for module_config in modules_config:
            module_class = import_from_string(module_config["type"])
            # For Transformer, don't load the full directory, rely on `transformers` instead
            # But, do load the config file first.
            if module_class == Transformer and module_config["path"] == "":
                pass
            else:
                # Normalize does not require any files to be loaded
                if module_class == Normalize:
                    module_path = None
                else:
                    module_path = os.path.join(config.ckpt_path, module_config["path"])
                module = module_class.load(module_path)
                modules[module_config["name"]] = module
        self.model = nn.Sequential(modules).cuda().to(dtype)

    def __call__(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, input_length: List[int], do_normalize: bool) -> torch.Tensor:
        """Run the module stack on token embeddings and return sentence embeddings.

        NOTE(review): `do_normalize` is ignored here — normalization happens
        only if the checkpoint's module stack includes a Normalize module;
        confirm that callers expect this.
        """
        input = {
            "token_embeddings": hidden_states,
            "attention_mask": attention_mask
        }
        return self.model(input)['sentence_embedding']

maga_transformer/async_decoder_engine/embedding/post_process/post_process_factory.py

Lines changed: 0 additions & 24 deletions
This file was deleted.
Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,77 @@
11
import torch
2-
from typing import List
3-
from maga_transformer.async_decoder_engine.embedding.embedding_stream import EmbeddingBatchedInput, EmbeddingOutput
2+
from typing import List, Union, Optional, Dict, Tuple
3+
from torch.nn.utils.rnn import pad_sequence
4+
from transformers import PreTrainedTokenizerBase
45

6+
from maga_transformer.utils.util import to_cuda
7+
from maga_transformer.embedding.embedding_config import EmbeddingGenerateConfig, EmbeddingType
8+
from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters
9+
from maga_transformer.async_decoder_engine.embedding.post_process.dense_embedding_module import init_dense_embedding_module
10+
from maga_transformer.async_decoder_engine.embedding.post_process.sparse_emebdding_module import init_sparse_embedding_module
11+
from maga_transformer.async_decoder_engine.embedding.post_process.colbert_embedding_module import init_colbert_embedding_module
12+
from maga_transformer.async_decoder_engine.embedding.embedding_stream import EmbeddingBatchedInput, EmbeddingOutput
513

614
class PostProcessModule(object):
    """Converts packed decoder hidden states into per-request embedding outputs.

    Depending on the request's EmbeddingGenerateConfig.type this produces a
    dense sentence embedding, a sparse (token-weight) embedding, or colbert
    per-token embeddings. The sparse and colbert heads are optional: they are
    None when the checkpoint does not ship their extra weights, and selecting
    them then raises.
    """

    def __init__(self, config: GptInitModelParameters, dtype: Union[torch.dtype, str], tokenizer: PreTrainedTokenizerBase):
        self.config_ = config
        self.dtype_ = dtype
        self.tokenizer_ = tokenizer
        # Fall back to 0 when the tokenizer defines no pad token id.
        self.pad_token_id_ = self.tokenizer_.pad_token_id if self.tokenizer_.pad_token_id is not None else 0
        self.dense_embedding_module_ = init_dense_embedding_module(config, dtype)
        # These two may be None when the checkpoint lacks the corresponding weights.
        self.sparse_embedding_module_ = init_sparse_embedding_module(config, tokenizer, dtype)
        self.colbert_embedding_module_ = init_colbert_embedding_module(config, dtype)

    # attention_mask from [batch, max_seq, max_seq] to [batch, max_seq]
    # hidden_states/input_ids from [combo_length, hidden_states] to [batch, max_seq, hidden_states]
    def _reorder_input(self, batch_input: EmbeddingBatchedInput, hidde_states: torch.Tensor, attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Unpack the concatenated batch into padded per-sequence tensors."""
        sliced_hidden_states: List[torch.Tensor] = []
        sliced_input_ids: List[torch.Tensor] = []
        attention_mask_indexs: List[int] = []
        hidden_bias = 0
        mask_bias = 0
        for input_length in batch_input.context_lengths_list:
            sliced_hidden_states.append(hidde_states[hidden_bias: hidden_bias + input_length])
            sliced_input_ids.append(torch.IntTensor(batch_input.combo_tokens[hidden_bias: hidden_bias + input_length]))
            # Index of the mask row that covers the full sequence length.
            attention_mask_indexs.append(mask_bias + input_length - 1)
            mask_bias += attention_mask.shape[1]
            hidden_bias += input_length
        # NOTE(review): hidden states are padded with pad_token_id_, a token id
        # rather than an embedding value — presumably harmless because padded
        # positions are masked downstream; confirm.
        batched_hidden_states = pad_sequence(sliced_hidden_states, batch_first=True, padding_value=self.pad_token_id_)
        batched_input_ids = pad_sequence(sliced_input_ids, batch_first=True, padding_value=self.pad_token_id_)
        batched_attention_mask = attention_mask.reshape(-1, attention_mask.shape[2])[attention_mask_indexs].contiguous()
        return batched_input_ids, batched_hidden_states, batched_attention_mask

    def _set_outputs(self, outputs: List[EmbeddingOutput],
                     dense_embedding: Optional[torch.Tensor],
                     sparse_embedding: Optional[List[Dict[str, float]]],
                     colbert_embedding: Optional[List[torch.Tensor]]):
        """Scatter whichever embedding kind was computed into per-request outputs."""
        if dense_embedding is not None:
            for index, dense in enumerate(dense_embedding):
                outputs[index].sentence_embedding = dense
        if sparse_embedding is not None:
            for index, sparse in enumerate(sparse_embedding):
                outputs[index].sparse_embedding = sparse
        if colbert_embedding is not None:
            for index, colbert in enumerate(colbert_embedding):
                outputs[index].colbert_embedding = colbert

    def process(self, batch_input: EmbeddingBatchedInput, hidde_states: torch.Tensor, attention_mask: torch.Tensor, embedding_config: EmbeddingGenerateConfig) -> List[EmbeddingOutput]:
        """Run the head selected by embedding_config.type over the batch.

        Raises when a sparse/colbert embedding is requested but the checkpoint
        does not provide the corresponding module.
        """
        outputs = [EmbeddingOutput() for _ in range(batch_input.batch_size)]
        batch_input_ids, batch_hidden_states, batch_attention_mask = self._reorder_input(batch_input, hidde_states, attention_mask)
        dense_embedding = None
        sparse_embedding = None
        colbert_embedding = None
        # The embedding types are mutually exclusive, so dispatch with elif.
        if embedding_config.type == EmbeddingType.DENSE:
            dense_embedding = self.dense_embedding_module_(hidden_states=batch_hidden_states, attention_mask=batch_attention_mask,
                                                           input_length=batch_input.context_lengths_list,
                                                           do_normalize=embedding_config.do_normalize)
        elif embedding_config.type == EmbeddingType.SPARSE:
            if self.sparse_embedding_module_ is None:
                raise Exception("module not support sparse embedding")
            sparse_embedding = self.sparse_embedding_module_(batch_input_ids, batch_hidden_states)
        elif embedding_config.type == EmbeddingType.COLBERT:
            if self.colbert_embedding_module_ is None:
                raise Exception("module not support colbert embedding")
            colbert_embedding = self.colbert_embedding_module_(batch_hidden_states, batch_attention_mask, batch_input.context_lengths_list, do_normalize=embedding_config.do_normalize)
        self._set_outputs(outputs, dense_embedding, sparse_embedding, colbert_embedding)
        return outputs

0 commit comments

Comments
 (0)