
Commit 46e5551

Implement InTokens data stream for ChatGLM (#6701)
* fix styles
* to list
* benchmarks
* ready for PR
1 parent 262b23a commit 46e5551

File tree

7 files changed, +96 -57 lines changed

examples/benchmark/peft/paddle/benchmark.py

Lines changed: 12 additions & 14 deletions
@@ -142,23 +142,21 @@ def preprocess_function_chatglm(example, max_src_length=256, max_tgt_length=384,
     model_inputs["input_ids"] = model_inputs["input_ids"][:-1]
     model_inputs["labels"] = model_inputs["labels"][1:]
 
-    context_length = model_inputs["input_ids"].index(tokenizer.bos_token_id)
-    seq_length = len(model_inputs["input_ids"])
-    position_ids = np.arange(seq_length, dtype=np.int64)
-    block_position_ids = np.concatenate(
-        [
-            np.zeros(context_length, dtype=np.int64),
-            np.arange(1, seq_length - context_length + 1, dtype=np.int64),
-        ]
-    )
-    model_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
-    # attention mask
     if intokens:
-        attention_mask = np.ones((seq_length, seq_length))
-        attention_mask = np.tril(attention_mask)
+        context_length = model_inputs["input_ids"].index(tokenizer.bos_token_id)
+        seq_length = len(model_inputs["input_ids"])
+        position_ids = np.arange(seq_length, dtype=np.int64)
+        block_position_ids = np.concatenate(
+            [
+                np.zeros(context_length, dtype=np.int64),
+                np.arange(1, seq_length - context_length + 1, dtype=np.int64),
+            ]
+        )
+        model_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
+        attention_mask = np.tri(seq_length, seq_length, dtype=bool)
         attention_mask[:, :context_length] = 1
-        attention_mask = (attention_mask < 0.5).astype("int64")
         model_inputs["attention_mask"] = attention_mask
+
     return model_inputs
 
 def preprocess_function_bloom(example, max_src_length=256, max_tgt_length=384, intokens=False):
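For orientation, here is a minimal standalone sketch (not part of the commit) of the arrays the new `if intokens:` branch builds for ChatGLM; the prompt/response lengths are invented for illustration:

```python
import numpy as np

# Toy sizes (invented): a 4-token prompt followed by a 3-token response.
context_length = 4
seq_length = 7

# ChatGLM's 2D position ids: row 0 counts positions normally, row 1 ("block"
# positions) stays 0 over the prompt and counts 1, 2, ... over the response.
position_ids = np.arange(seq_length, dtype=np.int64)
block_position_ids = np.concatenate(
    [
        np.zeros(context_length, dtype=np.int64),
        np.arange(1, seq_length - context_length + 1, dtype=np.int64),
    ]
)
print(np.stack([position_ids, block_position_ids], axis=0))
# [[0 1 2 3 4 5 6]
#  [0 0 0 0 1 2 3]]

# Prefix-LM attention mask: lower-triangular (causal) overall, but every token
# may attend to the full prompt prefix.
attention_mask = np.tri(seq_length, seq_length, dtype=bool)
attention_mask[:, :context_length] = 1
print(attention_mask.astype(int)[-1])  # last row: the final token sees every position
```

Compared with the removed code, the mask is built directly as a boolean lower-triangular matrix instead of being inverted with `(attention_mask < 0.5)`, and both the mask and the 2D position ids are only materialized when `intokens` is enabled.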

llm/causallm/data.py

Lines changed: 24 additions & 11 deletions
@@ -88,7 +88,7 @@ def convert_example_common(example, tokenizer, data_args, is_test=True, intokens
     input_ids = tokenized_source["input_ids"] + tokenized_target_input_ids
     source_length = len(tokenized_source["input_ids"])
     labels = [-100] * source_length + input_ids[source_length:]
-    # shift labels
+    # shift input_ids and labels
     input_ids, labels = input_ids[:-1], labels[1:]
     features = {
         "input_ids": input_ids,
@@ -97,7 +97,7 @@ def convert_example_common(example, tokenizer, data_args, is_test=True, intokens
     seq_length = len(input_ids)
     if intokens:
         features["position_ids"] = list(range(seq_length))
-        features["attention_mask"] = np.tril(np.ones((seq_length, seq_length), dtype="bool"))
+        features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)
 
     return features
 
@@ -115,15 +115,28 @@ def convert_example_chatglm(example, tokenizer, data_args, is_test=True, intoken
     else:
         input_ids = tokenized_source["input_ids"] + tokenized_target_input_ids
     bos_position = len(tokenized_source["input_ids"]) - 1
-
-    attention_mask = np.tri(len(input_ids), len(input_ids))
-    attention_mask[:, :bos_position] = 1
-    attention_mask = attention_mask[None, :, :]
-
     labels = [-100] * bos_position + input_ids[bos_position:]
-
-    # shift labels
+    # shift input_ids and labels
     input_ids, labels = input_ids[:-1], labels[1:]
-    attention_mask = attention_mask[..., :-1, :-1]
+    features = {
+        "input_ids": input_ids,
+        "labels": labels,
+    }
 
-    return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
+    if intokens:
+        seq_length = len(input_ids)
+        # attention_mask
+        attention_mask = np.tri(seq_length, seq_length, dtype=bool)
+        attention_mask[:, :bos_position] = 1
+        features["attention_mask"] = attention_mask
+        # 2d position_ids
+        position_ids = np.arange(seq_length, dtype=np.int64)
+        block_position_ids = np.concatenate(
+            [
+                np.zeros(bos_position, dtype=np.int64),
+                np.arange(1, seq_length - bos_position + 1, dtype=np.int64),
+            ]
+        )
+        features["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
+
+    return features
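A quick standalone check (not part of the commit) on the mask construction used above: `np.tri(n, n, dtype=bool)` yields the same lower-triangular boolean matrix as the older `np.tril(np.ones(...))` pattern, without an intermediate dense float array; note that `np.tri` takes the two dimensions as separate integers, not a shape tuple. Opening up the prompt columns then turns it into the ChatGLM prefix mask:

```python
import numpy as np

n = 5
causal_a = np.tri(n, n, dtype=bool)                # one call, boolean output
causal_b = np.tril(np.ones((n, n), dtype=bool))    # older two-step pattern
assert (causal_a == causal_b).all()

# Prefix-LM variant for a (toy) 2-token prompt: every token sees the prompt.
bos_position = 2
prefix_mask = causal_a.copy()
prefix_mask[:, :bos_position] = True
print(prefix_mask.astype(int))
```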

llm/causallm/finetune_generation.py

Lines changed: 2 additions & 2 deletions
@@ -118,8 +118,8 @@ def main():
     train_ds, dev_ds = load_dataset(data_args.dataset_name_or_path, splits=["train", "dev"])
     trans_func = partial(get_convert_example(model), tokenizer=tokenizer, data_args=data_args)
     if data_args.intokens:
-        if model.base_model_prefix not in ["llama", "bloom"]:
-            raise NotImplementedError("InTokens data stream is only implemented for LLaMA Bloom so far.")
+        if model.base_model_prefix not in ["llama", "bloom", "chatglm"]:
+            raise NotImplementedError("InTokens data stream is only implemented for LLaMA, Bloom and ChatGLM so far.")
     train_ds = train_ds.map(partial(trans_func, is_test=False, intokens=data_args.intokens))
     eval_intokens = data_args.intokens
     if data_args.intokens and data_args.eval_with_do_generation:

paddlenlp/datasets/intokens_dataset.py

Lines changed: 5 additions & 6 deletions
@@ -46,17 +46,16 @@ def _pad_batch_records(cls, batch_records):
             # If attention_mask is not given, assume it's causal mask
             attention_mask = record.get("attention_mask", np.tril(np.ones([seq_length, seq_length], dtype=bool)))
             batched_features["attention_mask"].append(attention_mask)
-            # TODO: to adapt to chatglm position_2d
             # NOTE: position_ids is optional and not required by every model
+            # We append instead of extend here to accommodate 2D position ids
            if "position_ids" in record:
-                batched_features["position_ids"].extend(record["position_ids"])
+                batched_features["position_ids"].append(record["position_ids"])
         block_attention_mask = block_diag(*batched_features["attention_mask"])
         # convert to 3-D [batch_size(1), seq_length, seq_length]
         batched_features["attention_mask"] = np.expand_dims(block_attention_mask, axis=0)
-        # batched_features["input_ids"] = np.array(batched_features["input_ids"], dtype=np.int64)
-        # batched_features["labels"] = np.array(batched_features["labels"], dtype=np.int64)
-        # if "position_ids" in record:
-        #     batched_features["position_ids"] = np.array(batched_features["position_ids"], dtype=np.int64)
+        if "position_ids" in batched_features:
+            # Accommodate both 1D and 2D position ids
+            batched_features["position_ids"] = np.concatenate(batched_features["position_ids"], axis=-1).tolist()
         return batched_features
 
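To see why `_pad_batch_records` now appends per-record position ids and concatenates them on the last axis, here is a minimal sketch with invented shapes (assuming the `block_diag` called above is `scipy.linalg.block_diag`):

```python
import numpy as np
from scipy.linalg import block_diag  # assumption: the block_diag used in the file

# Two toy records: a 3-token example and a 4-token example.
masks = [np.tril(np.ones([3, 3], dtype=bool)), np.tril(np.ones([4, 4], dtype=bool))]
packed = np.expand_dims(block_diag(*masks), axis=0)
print(packed.shape)  # (1, 7, 7): tokens of one record cannot attend to the other

# 1D position ids: append + concatenate(axis=-1) reproduces the old extend().
pos_1d = [np.arange(3), np.arange(4)]
print(np.concatenate(pos_1d, axis=-1).tolist())  # [0, 1, 2, 0, 1, 2, 3]

# 2D ChatGLM position ids: the same call keeps the leading (2, ...) rows and
# joins along the sequence axis, which a plain list extend() could not do.
pos_2d = [np.zeros((2, 3), dtype=np.int64), np.ones((2, 4), dtype=np.int64)]
print(np.concatenate(pos_2d, axis=-1).shape)  # (2, 7)
```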

paddlenlp/transformers/chatglm/modeling.py

Lines changed: 2 additions & 2 deletions
@@ -581,8 +581,8 @@ class ChatGLMPretrainedModel(PretrainedModel):
     model_config_file = CONFIG_NAME
     resource_files_names = {"model_state": "model_state.pdparams"}
     pretrained_resource_files_map = CHATGLM_PRETRAINED_RESOURCE_FILES_MAP
-    _keys_to_ignore_on_load_missing = [r"transformer.layers.*.attention.rotary_embeddings.inv_freq"]
-    _keys_to_ignore_on_load_unexpected = [r"transformer.layers.*.attention.rotary_emb.inv_freq"]
+    _keys_to_ignore_on_load_missing = [r"transformer.rotary_embeddings.inv_freq", r"lm_head.decoder_weight"]
+    _keys_to_ignore_on_load_unexpected = [r"transformer.rotary_emb.inv_freq"]
 
     def init_weights(self, layer):
         """Initialization hook"""

paddlenlp/transformers/conversion_utils.py

Lines changed: 2 additions & 2 deletions
@@ -1025,7 +1025,7 @@ def convert_tensor_parallel(
         if state_dict is None:
             with device_guard("cpu"):
                 state_dict = paddle.load(weight_file, return_numpy=False)
-        logger.info("starting convert orignal state_dict to tensor parallel state_dict.")
+        logger.info("Starting to convert original state_dict to tensor parallel state_dict.")
 
         state_keys_map = cls._resolve_prefix_keys(name_action_mappings.keys(), state_dict.keys(), ignore_error)
 
@@ -1035,7 +1035,7 @@ def convert_tensor_parallel(
         for name, action in name_action_mappings.items():
             if name not in state_dict:
                 if not ignore_error:
-                    logger.warning(f"key<{name}> not in the model state weight file.")
+                    logger.warning(f"Key <{name}> not in the model state weight file.")
                 continue
             tensor = state_dict.pop(name)
             new_tensor = action(tensor)

tests/dataset/test_intokens.py

Lines changed: 49 additions & 20 deletions
@@ -36,7 +36,7 @@ class InTokensTestCommon:
     expected_output = {
         "input_ids": [1, 29871, 30429, 1, 29871, 30429, 2, 1, 29871, 31427, 1, 29871, 31427, 2],
         "labels": [-100, -100, -100, 1, 29871, 30429, 2, -100, -100, -100, 1, 29871, 31427, 2],
-        "position_ids": [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6],
+        "position_ids": np.array([0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6]),
         "attention_mask": np.array(
             [
                 [
@@ -57,25 +57,34 @@ class InTokensTestCommon:
                 ]
             ]
         ),
+        "position_ids_2d": [[0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6], [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6]],
     }
 
-    def preprocess_fn(self, example, max_src_length=3, max_tgt_length=3):
+    def preprocess_fn(
+        self,
+        example,
+        max_src_length=3,
+        max_tgt_length=3,
+        return_position_ids=True,
+        position_ids_2d=False,
+        return_attention_mask=True,
+    ):
         inputs = example["sentence"][:2]
         model_inputs = self.tokenizer(inputs, max_length=max_src_length, truncation=True, return_attention_mask=False)
         labels_input_ids = model_inputs["input_ids"] + [self.tokenizer.eos_token_id]
         model_inputs["labels"] = [-100] * len(model_inputs["input_ids"]) + labels_input_ids
         model_inputs["input_ids"] = model_inputs["input_ids"] + labels_input_ids
         seq_length = len(model_inputs["input_ids"])
-        model_inputs["position_ids"] = list(range(seq_length))
-        model_inputs["attention_mask"] = np.tril(np.ones([seq_length, seq_length]))
-        return model_inputs
-
-    def preprocess_fn_input_labels_only(self, example, max_src_length=3, max_tgt_length=3):
-        inputs = example["sentence"][:2]
-        model_inputs = self.tokenizer(inputs, max_length=max_src_length, truncation=True, return_attention_mask=False)
-        labels_input_ids = model_inputs["input_ids"] + [self.tokenizer.eos_token_id]
-        model_inputs["labels"] = [-100] * len(model_inputs["input_ids"]) + labels_input_ids
-        model_inputs["input_ids"] = model_inputs["input_ids"] + labels_input_ids
+        if return_position_ids:
+            if position_ids_2d:
+                position_ids = np.arange(seq_length, dtype=np.int64)
+                # fake block_position_ids with wrong values but correct shape
+                block_position_ids = np.arange(seq_length, dtype=np.int64)
+                model_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
+            else:
+                model_inputs["position_ids"] = list(range(seq_length))
+        if return_attention_mask:
+            model_inputs["attention_mask"] = np.tril(np.ones([seq_length, seq_length]))
         return model_inputs
 
 
@@ -89,10 +98,14 @@ def setUpClass(cls):
             data_files=[os.path.join(fixture_path, "tnews", "train.json")],
             lazy=False,
         )
-        copy_train_ids = copy.deepcopy(cls.train_ds)
+        copy_dataset_1 = copy.deepcopy(cls.train_ds)
+        copy_dataset_2 = copy.deepcopy(cls.train_ds)
         cls.dataset = cls.train_ds.map(lambda example: cls.preprocess_fn(cls, example))
-        cls.dataset_input_labels_only = copy_train_ids.map(
-            lambda example: cls.preprocess_fn_input_labels_only(cls, example)
+        cls.dataset_position_2d = copy_dataset_1.map(
+            lambda example: cls.preprocess_fn(cls, example, position_ids_2d=True)
+        )
+        cls.dataset_input_labels_only = copy_dataset_2.map(
+            lambda example: cls.preprocess_fn(cls, example, return_position_ids=False, return_attention_mask=False)
         )
 
     def test_long_max_length(self):
@@ -111,8 +124,8 @@ def test_long_max_length(self):
     def test_short_max_length(self):
         inData = InTokensMapDataset(self.dataset, self.tokenizer, max_length=16)
         self.assertEqual(inData[0]["input_ids"], self.expected_output["input_ids"])
-        self.assertEqual(inData[0]["position_ids"], self.expected_output["position_ids"])
         self.assertEqual(inData[0]["labels"], self.expected_output["labels"])
+        self.assertTrue((inData[0]["position_ids"] == self.expected_output["position_ids"]).all())
         self.assertTrue((inData[0]["attention_mask"] == self.expected_output["attention_mask"]).all())
 
         inData_input_labels_only = InTokensMapDataset(self.dataset_input_labels_only, self.tokenizer, max_length=16)
@@ -122,6 +135,10 @@ def test_short_max_length(self):
             (inData_input_labels_only[0]["attention_mask"] == self.expected_output["attention_mask"]).all()
         )
 
+    def test_2d_position_id(self):
+        inData_2d = InTokensMapDataset(self.dataset_position_2d, self.tokenizer, max_length=16)
+        self.assertTrue((inData_2d[0]["position_ids"] == self.expected_output["position_ids_2d"]).all())
+
     def test_missing_data(self):
         orginal_input_ids = [item["input_ids"] for item in self.dataset]
         orginal_input_ids = [sum(orginal_input_ids, [])]
@@ -138,10 +155,14 @@ def setUpClass(cls):
         cls.train_ds = load_dataset(
             read_local_dataset, path=os.path.join(fixture_path, "tnews", "train.json"), lazy=True
        )
-        copy_train_ids = copy.deepcopy(cls.train_ds)
+        copy_dataset_1 = copy.deepcopy(cls.train_ds)
+        copy_dataset_2 = copy.deepcopy(cls.train_ds)
         cls.dataset = cls.train_ds.map(lambda example: cls.preprocess_fn(cls, example))
-        cls.dataset_input_labels_only = copy_train_ids.map(
-            lambda example: cls.preprocess_fn_input_labels_only(cls, example)
+        cls.dataset_position_2d = copy_dataset_1.map(
+            lambda example: cls.preprocess_fn(cls, example, position_ids_2d=True)
+        )
+        cls.dataset_input_labels_only = copy_dataset_2.map(
+            lambda example: cls.preprocess_fn(cls, example, return_position_ids=False, return_attention_mask=False)
        )
 
     def test_long_max_length(self):
@@ -174,8 +195,8 @@ def test_short_max_length(self):
             example.append(item)
             break
         self.assertEqual(example[0]["input_ids"], self.expected_output["input_ids"])
-        self.assertEqual(example[0]["position_ids"], self.expected_output["position_ids"])
         self.assertEqual(example[0]["labels"], self.expected_output["labels"])
+        self.assertTrue((example[0]["position_ids"] == self.expected_output["position_ids"]).all())
         self.assertTrue((example[0]["attention_mask"] == self.expected_output["attention_mask"]).all())
 
         inData_input_labels_only = InTokensIterableDataset(
@@ -189,6 +210,14 @@ def test_short_max_length(self):
         self.assertEqual(example[0]["labels"], self.expected_output["labels"])
         self.assertTrue((example[0]["attention_mask"] == self.expected_output["attention_mask"]).all())
 
+    def test_2d_position_id(self):
+        inData_2d = InTokensIterableDataset(self.dataset_position_2d, self.tokenizer, max_length=16)
+        example = []
+        for item in inData_2d:
+            example.append(item)
+            break
+        self.assertTrue((example[0]["position_ids"] == self.expected_output["position_ids_2d"]).all())
+
     def test_missing_data(self):
         orginal_input_ids = [item["input_ids"] for item in self.dataset]
         orginal_input_ids = [sum(orginal_input_ids, [])]
