From a49574f568f21b4e30c3055638c5fbb4eb5eae48 Mon Sep 17 00:00:00 2001
From: Lusfie <15176063690@163.com>
Date: Wed, 23 Oct 2024 13:54:52 +0800
Subject: [PATCH 1/2] update: llama3

---
 .gitignore                                |   4 +
 projects/Llama/utils/prepare_alpaca.py    |   3 +-
 projects/Llama3/README.md                 |  60 ++
 projects/Llama3/adapter/adapter_config.py |  63 ++
 projects/Llama3/adapter/adapter_model.py  | 730 ++++++++++++++++++++++
 projects/Llama3/adapter/adapter_sft.py    |  97 +++
 projects/Llama3/adapter/train_net.py      | 115 ++++
 projects/Llama3/configs/llama_config.py   |  61 ++
 projects/Llama3/configs/llama_sft.py      |  97 +++
 projects/Llama3/dataset.py                |  19 +
 projects/Llama3/llama.py                  | 647 +++++++++++++++++++
 projects/Llama3/pipeline.py               | 128 ++++
 projects/Llama3/tokenizer.py              | 112 ++++
 projects/Llama3/train_net.py              | 102 +++
 projects/Llama3/utils/eval_adapter.py     | 177 ++++++
 projects/Llama3/utils/llama_loader.py     | 110 ++++
 projects/Llama3/utils/prepare_alpaca.py   | 162 +++++
 17 files changed, 2686 insertions(+), 1 deletion(-)
 create mode 100644 projects/Llama3/README.md
 create mode 100644 projects/Llama3/adapter/adapter_config.py
 create mode 100644 projects/Llama3/adapter/adapter_model.py
 create mode 100644 projects/Llama3/adapter/adapter_sft.py
 create mode 100644 projects/Llama3/adapter/train_net.py
 create mode 100644 projects/Llama3/configs/llama_config.py
 create mode 100644 projects/Llama3/configs/llama_sft.py
 create mode 100644 projects/Llama3/dataset.py
 create mode 100644 projects/Llama3/llama.py
 create mode 100644 projects/Llama3/pipeline.py
 create mode 100644 projects/Llama3/tokenizer.py
 create mode 100644 projects/Llama3/train_net.py
 create mode 100644 projects/Llama3/utils/eval_adapter.py
 create mode 100644 projects/Llama3/utils/llama_loader.py
 create mode 100644 projects/Llama3/utils/prepare_alpaca.py

diff --git a/.gitignore b/.gitignore
index ab28bbee3..a951be864 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,7 @@
+# data file
+alpaca_data/
+libai/version.py
+sft_result
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/projects/Llama/utils/prepare_alpaca.py b/projects/Llama/utils/prepare_alpaca.py
index c21f505fb..047718004 100644
--- a/projects/Llama/utils/prepare_alpaca.py
+++ b/projects/Llama/utils/prepare_alpaca.py
@@ -114,7 +114,8 @@ def prepare_sample(example: dict, tokenizer, max_length: int) -> dict:
     prompt = tokenizer.tokenize(full_prompt, add_bos=True, add_eos=False, device="cpu")[0]
     example = tokenizer.tokenize(
-        full_prompt_and_response, add_bos=True, add_eos=True, device="cpu"
+        full_prompt_and_response, add_bos=True, add_eos=True, device=None,
+        # device="cpu"
     )[0]
     padding = max_length - example.shape[0]
diff --git a/projects/Llama3/README.md b/projects/Llama3/README.md
new file mode 100644
index 000000000..0b9673d66
--- /dev/null
+++ b/projects/Llama3/README.md
@@ -0,0 +1,60 @@
+# Llama3
+
+Reproduce Llama3 with OneFlow; the results are equivalent to HuggingFace's [Llama3](https://huggingface.co/docs/transformers/main/en/model_doc/llama3#overview).
+
+## Introduction
+The Llama3 supervised fine-tuning (SFT) project supports 3D parallelism.
+
+## Fine-tuning Llama3
+Fine-tune Llama3 on 8 GPUs using parallelism.
+
+### 1. Prepare the alpaca dataset
+
+> Set the parameters in `projects/Llama3/utils/prepare_alpaca.py` to prepare the dataset, such as `destination_path` and `checkpoint_dir`.
+
+> Get the alpaca dataset files by running:
+```bash
+# path/to/libai
+python projects/Llama3/utils/prepare_alpaca.py
+```
+
+### 2. Prepare your finetuning config file
+
+> Set the finetuning parameters in `projects/Llama3/configs/llama_sft.py`, such as `dataset_path` and `pretrained_model_path`.
+
+### 3. Run the following commands to start SFT
+```bash
+# full finetune
+bash tools/train.sh projects/Llama3/train_net.py projects/Llama3/configs/llama_sft.py 8
+
+# adapter finetune
+bash tools/train.sh projects/Llama3/adapter/train_net.py projects/Llama3/adapter/adapter_sft.py 8
+```
+
+## Evaluate
+
+> Set the evaluation parameters in `projects/Llama3/utils/eval_adapter.py`, then run:
+```bash
+python projects/Llama3/utils/eval_adapter.py
+```
+
+## Llama3 Inference
+
+- Prepare the Llama3 checkpoint.
+- Adjust the parameters in `projects/Llama3/pipeline.py`, then run:
+```bash
+bash tools/infer.sh projects/Llama3/pipeline.py 8
+```
+
+## NPU/XPU examples
+
+- npu
+```bash
+python projects/Llama3/pipeline.py --device=npu --mode=huggingface --model_path /your/model/path
+```
+
+- xpu
+```bash
+python projects/Llama3/pipeline.py --device=xpu --mode=huggingface --model_path /your/model/path
+```
+
diff --git a/projects/Llama3/adapter/adapter_config.py b/projects/Llama3/adapter/adapter_config.py
new file mode 100644
index 000000000..320ade218
--- /dev/null
+++ b/projects/Llama3/adapter/adapter_config.py
@@ -0,0 +1,63 @@
+from omegaconf import DictConfig, OmegaConf
+
+from configs.common.train import train  # noqa
+from libai.config import LazyCall
+from projects.Llama3.adapter.adapter_model import LlamaForCausalLM
+from projects.Llama3.tokenizer import LlamaTokenizer
+
+cfg = dict(
+    # Model
+    hidden_act="silu",
+    hidden_size=4096,
+    initializer_range=0.02,
+    intermediate_size=14336,
+    max_position_embeddings=8192,
+    num_attention_heads=32,
+    hidden_layers=32,
+    pretraining_tp=1,
+    rms_norm_eps=1e-05,
+    rope_scaling=None,
+    tie_word_embeddings=False,
+    vocab_size=128256,
+    use_scaled_init_for_output_weights=False,
+    scale_mask_softmax_fusion=False,
+    amp_enabled=True,
+    # Inference
+    is_encoder_decoder=False,
+    max_length=256,
+    min_length=0,
+    do_sample=False,
+    early_stopping=False,
+    num_beams=1,
+    num_beam_groups=1,
+    diversity_penalty=0.0,
+    temperature=0.9,
+    top_k=50,
+    top_p=0.6,
+    typical_p=1.0,
+    repetition_penalty=1.0,
+    length_penalty=1.0,
+    no_repeat_ngram_size=0,
+    encoder_no_repeat_ngram_size=0,
+    num_return_sequences=1,
+    chunk_size_feed_forward=0,
+    output_scores=False,
+    use_cache=True,
+    bos_token_id=1,
+    eos_token_id=2,
+    pad_token_id=0,
+    # adapter
+    adapter_len=10,
+    adapter_layer=30,
+    # train
+    pretrained_model_path="meta-llama/Llama-3-8B/",
+)
+
+cfg = DictConfig(cfg)
+
+model = LazyCall(LlamaForCausalLM)(cfg=cfg)
+tokenization = OmegaConf.create()
+tokenization.make_vocab_size_divisible_by = 1
+tokenization.tokenizer = LazyCall(LlamaTokenizer)(
+    pretrained_model_path="Llama-3-8B/tokenizer.model"
+)
diff --git a/projects/Llama3/adapter/adapter_model.py b/projects/Llama3/adapter/adapter_model.py
new file mode 100644
index 000000000..78a72b4fe
--- /dev/null
+++ b/projects/Llama3/adapter/adapter_model.py
@@ -0,0 +1,730 @@
+# coding=utf-8
+# Copyright 2021 The OneFlow Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Tuple + +import oneflow as flow +import oneflow.nn.functional as F +from oneflow import nn + +from libai.config import configurable +from libai.inference.generator.generation_utils import Generator +from libai.layers import Embedding, Linear, RMSLayerNorm, VocabEmbedding +from libai.layers.attention import AttnMaskType +from libai.models.utils import init_method_normal, scaled_init_method_normal +from libai.utils import distributed as dist + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return flow.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + + def forward(self, x, seq_len=None, cos_cached=None, sin_cached=None): + if seq_len > self.max_position_embeddings: + raise ValueError( + f"The maximum supported length is {self.max_position_embeddings}, " + f"and the current length is{seq_len}." + ) + + return ( + cos_cached[:seq_len].to_global(placement=x.placement), + sin_cached[:seq_len].to_global(placement=x.placement), + ) + + +class MLP(nn.Module): + def __init__( + self, + hidden_size, + intermediate_size, + init_method=nn.init.xavier_normal_, + output_layer_init_method=None, + *, + layer_idx=0, + ): + super().__init__() + + if output_layer_init_method is None: + output_layer_init_method = init_method + + self.gate_proj = Linear( + hidden_size, + intermediate_size, + bias=False, + parallel="col", + init_method=init_method, + layer_idx=layer_idx, + ) + + self.up_proj = Linear( + hidden_size, + intermediate_size, + bias=False, + parallel="col", + init_method=init_method, + layer_idx=layer_idx, + ) + + self.down_proj = Linear( + intermediate_size, + hidden_size, + bias=False, + parallel="row", + init_method=output_layer_init_method, + layer_idx=layer_idx, + ) + + self.activation_func = nn.SiLU() + + def forward(self, hidden_states): + gate_out = self.activation_func(self.gate_proj(hidden_states)) + up_out = self.up_proj(hidden_states) + output = self.down_proj(gate_out * up_out) + return output + + +class MultiheadAttention(nn.Module): + def __init__( + self, + hidden_size, + num_attention_heads, + max_position_embeddings, + init_method=nn.init.xavier_normal_, + output_layer_init_method=None, + scale_mask_softmax_fusion=False, + attn_mask_type=AttnMaskType.padding, + *, + layer_idx=0, + ): + super().__init__() + self.hidden_size = hidden_size + if output_layer_init_method is None: + output_layer_init_method = init_method + + self.num_heads = num_attention_heads + self.head_size = hidden_size // num_attention_heads + self.attn_mask_type = attn_mask_type + + self.norm_factor = 1.0 / math.sqrt(float(self.head_size)) + + 
self.scale_mask_softmax_fusion = scale_mask_softmax_fusion + + self.query_key_value = Linear( + self.hidden_size, + self.hidden_size * 3, + bias=False, + parallel="col", + init_method=init_method, + layer_idx=layer_idx, + ) + + self.o_proj = Linear( + self.hidden_size, + self.hidden_size, + bias=False, + parallel="row", + init_method=output_layer_init_method, + layer_idx=layer_idx, + ) + + self.coeff = None + + rotary_dim = self.head_size + self.rotary_embed = RotaryEmbedding( + dim=rotary_dim, + max_position_embeddings=max_position_embeddings, + ) + + self.gate = flow.nn.Parameter( + flow.zeros( + 1, + self.num_heads, + 1, + 1, + placement=dist.get_layer_placement(layer_idx), + sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), + ) + ) + + def forward( + self, + hidden_states: flow.Tensor, + encoder_states: flow.Tensor = None, + attention_mask: flow.Tensor = None, + position_ids=None, + past_key_value: Tuple[flow.Tensor, flow.Tensor] = None, + cos_cached: flow.Tensor = None, + sin_cached: flow.Tensor = None, + use_cache: bool = False, + adapter=None, + ): + if encoder_states is not None: + encoder_states = encoder_states.to_global(placement=hidden_states.placement) + + if attention_mask is not None: + attention_mask = attention_mask.to_global(placement=hidden_states.placement) + + if adapter is not None: + adapter = adapter.to_global(placement=hidden_states.placement) + + bsz, tgt_len = hidden_states.size()[:2] + + query_key_value = self.query_key_value(hidden_states) + query_key_value = query_key_value.view(bsz, -1, self.num_heads, 3 * self.head_size) + query_key_value = query_key_value.permute( + 0, 2, 1, 3 + ) # [bsz, num_heads, src_len, 3 * head_size] + query, key, value = flow.chunk(query_key_value, chunks=3, dim=-1) + + kv_seq_len = key.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_embed( + value, seq_len=kv_seq_len, cos_cached=cos_cached, sin_cached=sin_cached + ) + query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids) + + # [1, adapter_len, 4096] + if adapter is not None: + adapter_len = adapter.shape[1] + adapter_qkv = self.query_key_value(adapter) + adapter_qkv = adapter_qkv.view(1, -1, self.num_heads, 3 * self.head_size) + adapter_qkv = adapter_qkv.permute(0, 2, 1, 3) # [1, num_heads, src_len, 3 * head_size] + _, adapter_key, adapter_value = flow.chunk(adapter_qkv, chunks=3, dim=-1) + adapter_key = adapter_key.repeat(bsz, 1, 1, 1) + adapter_value = adapter_value.repeat(bsz, 1, 1, 1) + key = flow.cat([adapter_key, key], dim=2) + value = flow.cat([adapter_value, value], dim=2) + extra_mask = flow.zeros( + bsz, + 1, + tgt_len, + adapter_len, + sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), + placement=attention_mask.placement, + ) + attention_mask = flow.cat([extra_mask, attention_mask], dim=-1) + + if past_key_value is not None: + past_key, past_value = past_key_value + key = flow.cat((past_key.type_as(key), key), dim=2) + value = flow.cat((past_value.type_as(value), value), dim=2) + + # query, key, value: [S(0), S(1)], shape: [bsz, num_heads, seq_length, head_size] + if use_cache: + past_key_value = (key, value) + + # [bsz, num_heads, tgt_len, src_len] with [S(0), S(1)] + attention_scores = flow.matmul(query, key, transpose_b=True, alpha=self.norm_factor) + attention_weights = attention_scores + attention_mask + + if adapter is not None: + attention_weights = flow.cat( + [ + self.gate.tanh().half() + * F.softmax(attention_weights[:, :, :, :adapter_len].float(), 
dim=-1).to( + query.dtype + ), + F.softmax(attention_weights[:, :, :, adapter_len:].float(), dim=-1).to( + query.dtype + ), + ], + dim=-1, + ) + else: + attention_weights = flow.softmax(attention_weights, dim=-1) + # Context shape: [bsz, num_heads, tgt_len, head_size] with [S(0), S(1)] + context = flow.matmul(attention_weights, value) + + # Change shape: [bsz, num_heads, tgt_len, head_size] -> [bsz, tgt_len, num_heads, head_size] + context = context.transpose(1, 2) + output = self.o_proj(context.flatten(2)) + + if use_cache: + output = (output, past_key_value) + + return output + + +class CasualMask(nn.Module): + def __init__(self, max_positions=1024, dtype=flow.float16, *, layer_idx=0): + super().__init__() + self.dtype = dtype + self.mask = flow.full( + (max_positions, max_positions), + flow.finfo(dtype).min, + placement=dist.get_layer_placement(layer_idx), + sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), + ) + mask_cond = flow.arange( + self.mask.size(-1), + placement=dist.get_layer_placement(layer_idx), + sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), + ) + self.mask.masked_fill_(mask_cond < (mask_cond + 1).view(self.mask.size(-1), 1), 0) + self.mask = self.mask.to(dtype) + + def forward(self, input_ids, past_length=0, attention_mask=None, input_dtype=None): + bsz, tgt_len = input_ids.size() + casual_mask = self.mask[:tgt_len, :tgt_len] + if past_length > 0: + # in case past_key_values are used, we need to add a prefix ones mask to casual mask + casual_mask = flow.cat( + [flow.ones(tgt_len, past_length, dtype=self.dtype), casual_mask], dim=-1 + ) + casual_mask = ( + casual_mask.unsqueeze(0).unsqueeze(1).expand(bsz, 1, tgt_len, tgt_len + past_length) + ) + casual_mask = casual_mask.to_global(sbp=input_ids.sbp) + if attention_mask is not None: + bsz, src_len = attention_mask.size() + attention_mask = ( + attention_mask[:, None, None, :] + .expand(bsz, 1, tgt_len, src_len) + .to(casual_mask.dtype) + ) + attention_mask = attention_mask.to_global(placement=casual_mask.placement) + casual_mask = casual_mask + attention_mask + if input_dtype is not None: + casual_mask = casual_mask.to(input_dtype) + return casual_mask + + +class LlamaDecoderLayer(nn.Module): + def __init__( + self, + hidden_size, + intermediate_size, + num_attention_heads, + is_decoder=False, + rms_norm_eps=1e-5, + max_position_embeddings=None, + init_method=nn.init.xavier_normal_, + output_layer_init_method=None, + scale_mask_softmax_fusion=False, + attn_mask_type=AttnMaskType.padding, + *, + layer_idx=0, + ): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.rms_norm_eps = rms_norm_eps + self.max_position_embeddings = max_position_embeddings + self.attn_mask_type = attn_mask_type + + self.layer_idx = layer_idx + self.is_decoder = is_decoder + + self.scale_mask_softmax_fusion = scale_mask_softmax_fusion + + self.init_method = init_method + if output_layer_init_method is None: + output_layer_init_method = init_method + self.output_layer_init_method = output_layer_init_method + + self.input_layernorm = RMSLayerNorm( + self.hidden_size, eps=self.rms_norm_eps, layer_idx=self.layer_idx + ) + + self.self_attn = self.build_attention() + self.post_attention_layernorm = RMSLayerNorm( + self.hidden_size, eps=self.rms_norm_eps, layer_idx=self.layer_idx + ) + + self.mlp = MLP( + self.hidden_size, + self.intermediate_size, + self.init_method, + output_layer_init_method=self.output_layer_init_method, + 
layer_idx=self.layer_idx, + ) + + def forward( + self, + hidden_states, + attention_mask=None, + past_key_value=None, + cos_cached=None, + sin_cached=None, + use_cache=False, + adapter=None, + ): + hidden_states = hidden_states.to_global(placement=dist.get_layer_placement(self.layer_idx)) + + # hidden_states shape: (batch_size, seq_length, hidden_size) + if attention_mask is not None: + attention_mask = attention_mask.to_global( + placement=dist.get_layer_placement(self.layer_idx) + ) + + if past_key_value is not None: + if self.is_decoder: + assert len(past_key_value) == 4 + self_attn_past_key_value = past_key_value[:2] + else: + self_attn_past_key_value = past_key_value + else: + self_attn_past_key_value = None + + layernorm_output = self.input_layernorm(hidden_states) + attention_output = self.self_attn( + layernorm_output, + attention_mask=attention_mask, + past_key_value=self_attn_past_key_value, + cos_cached=cos_cached, + sin_cached=sin_cached, + use_cache=use_cache, + adapter=adapter, + ) + + if use_cache: + attention_output, presents = attention_output + + hidden_states = hidden_states + attention_output + + layernorm_output = self.post_attention_layernorm(hidden_states) + + mlp_output = self.mlp(layernorm_output) + + output = hidden_states + mlp_output + + if use_cache: + output = (output, presents) + return output + + def build_attention(self): + return MultiheadAttention( + self.hidden_size, + self.num_attention_heads, + self.max_position_embeddings, + init_method=self.init_method, + output_layer_init_method=self.output_layer_init_method, + scale_mask_softmax_fusion=self.scale_mask_softmax_fusion, + attn_mask_type=self.attn_mask_type, + layer_idx=self.layer_idx, + ) + + +class LlamaModel(nn.Module): + def __init__( + self, + hidden_layers, + vocab_size, + hidden_size, + intermediate_size, + num_attention_heads, + max_position_embeddings=1024, + rms_norm_eps=1e-5, + initializer_range=0.02, + use_scaled_init_for_output_weights=True, + scale_mask_softmax_fusion=False, + amp_enabled=False, + cfg=None, + ): + super().__init__() + self.cfg = cfg + init_method = init_method_normal(sigma=initializer_range) + if use_scaled_init_for_output_weights: + output_layer_init_method = scaled_init_method_normal(initializer_range, hidden_layers) + else: + output_layer_init_method = init_method + + self.embed_tokens = VocabEmbedding( + vocab_size, hidden_size, init_method=init_method, amp_enabled=amp_enabled + ) + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer( + hidden_size, + intermediate_size, + num_attention_heads, + rms_norm_eps=rms_norm_eps, + max_position_embeddings=max_position_embeddings, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + scale_mask_softmax_fusion=scale_mask_softmax_fusion, + attn_mask_type=AttnMaskType.causal, + layer_idx=i, + ) + for i in range(hidden_layers) + ] + ) + self.norm = RMSLayerNorm(hidden_size, eps=rms_norm_eps, layer_idx=-1) + + self.adapter_query = Embedding( + cfg.adapter_len * cfg.adapter_layer, hidden_size, amp_enabled=amp_enabled + ) + + self._set_cos_sin_cache( + rotary_dim=hidden_size // num_attention_heads, + seq_len=max_position_embeddings, + dtype=flow.float32, + layer_idx=0, + ) + + def _set_cos_sin_cache(self, rotary_dim, seq_len, base=10000, dtype=None, layer_idx=0): + position = flow.arange( + 0, + rotary_dim, + 2, + dtype=dtype, + sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), + placement=dist.get_layer_placement(layer_idx), + ) + inv_freq = 1.0 / (base ** (position / rotary_dim)) + + t = 
flow.arange( + seq_len, + dtype=inv_freq.dtype, + sbp=inv_freq.sbp, + placement=inv_freq.placement, + ) + + freqs = flow.einsum("i,j->ij", t, inv_freq) + emb = flow.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype)) + self.register_buffer("sin_cached", emb.sin().to(dtype)) + + def forward( + self, + input_ids, + attention_mask=None, + past_key_values=None, + use_cache=False, + set_cache=None, + ): + with flow.no_grad(): + if use_cache: + presents = [] + input_ids = input_ids.to_global(placement=dist.get_layer_placement(0)) + hidden_states = self.embed_tokens(input_ids) + + for layer, past_key_value in zip( + self.layers[: -self.cfg.adapter_layer], past_key_values[: -self.cfg.adapter_layer] + ): + hidden_states = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + cos_cached=self.cos_cached, + sin_cached=self.sin_cached, + use_cache=False, + adapter=None, + ) + if use_cache: + hidden_states, present = hidden_states + presents.append(present) + + adapter_index = 0 + # [num_adapter_layer, 1, adapter_len, 4096] + adapter = self.adapter_query.weight.reshape(-1, self.cfg.adapter_len, 4096).unsqueeze(1) + for layer, past_key_value in zip( + self.layers[-self.cfg.adapter_layer :], past_key_values[-self.cfg.adapter_layer :] + ): + hidden_states = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + cos_cached=self.cos_cached, + sin_cached=self.sin_cached, + use_cache=False, + adapter=adapter[adapter_index], # [1, adapter_len, 4096] + ) + adapter_index += 1 + if use_cache: + hidden_states, present = hidden_states + presents.append(present) + + hidden_states = self.norm(hidden_states) + + if use_cache: + set_cache(presents) + + return hidden_states + + +class CrossEntropyLoss(nn.Module): + def forward(self, logits: flow.Tensor, target: flow.Tensor): + assert logits.ndim == 3 + assert target.ndim == 2 + assert logits.shape[0:2] == target.shape + + target = target.to_global(placement=logits.placement) + target = target * (target >= 0) + + lm_loss = flow._C.cross_entropy( + logits.view(-1, logits.shape[-1]), target.view(-1), ignore_index=0 + ) + return lm_loss + + +class SFTLoss(nn.Module): + def __init__(self) -> None: + super().__init__() + self.lm_loss = CrossEntropyLoss() + + def forward(self, logits, lm_labels): + lm_loss = self.lm_loss(logits, lm_labels) + lm_loss = lm_loss.mean() + return {"lm_loss": lm_loss} + + +class LlamaForCausalLM(nn.Module, Generator): + @configurable + def __init__( + self, + hidden_layers, + vocab_size, + hidden_size, + intermediate_size, + num_attention_heads, + max_position_embeddings=1024, + rms_norm_eps=1e-5, + initializer_range=0.02, + use_scaled_init_for_output_weights=True, + scale_mask_softmax_fusion=False, + amp_enabled=False, + cfg=None, + ): + super().__init__() + self.cfg = cfg + self.model = LlamaModel( + hidden_layers=hidden_layers, + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_attention_heads=num_attention_heads, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=rms_norm_eps, + initializer_range=initializer_range, + use_scaled_init_for_output_weights=use_scaled_init_for_output_weights, + scale_mask_softmax_fusion=scale_mask_softmax_fusion, + amp_enabled=amp_enabled, + cfg=cfg, + ) + self.casual_mask = CasualMask(max_position_embeddings, layer_idx=0) + self.lm_head = Linear(hidden_size, vocab_size, bias=False, layer_idx=-1) + self.loss_func = SFTLoss() 
+ + self.past_key_values = [None] * hidden_layers + self.past_length = 0 + + def forward(self, input_ids, attention_mask=None, labels=None, use_cache=False): + input_ids = input_ids.to_global(placement=dist.get_layer_placement(0)) + attention_mask = ( + attention_mask.to_global(placement=dist.get_layer_placement(0)) + if attention_mask is not None + else attention_mask + ) + labels = ( + labels.to_global(placement=dist.get_layer_placement(0)) + if labels is not None + else labels + ) + + if use_cache and self.past_key_values[0] is not None: + self.past_length = self.past_key_values[0][0].size(-2) + else: + self.past_length = 0 + + mask = self.casual_mask( + input_ids, + past_length=self.past_length, + attention_mask=attention_mask, + input_dtype=self.lm_head.weight.dtype, + ) + + output = self.model( + input_ids, + attention_mask=mask, + past_key_values=self.past_key_values, + use_cache=use_cache, + set_cache=self.set_cache, + ) + + logits = self.lm_head(output) + + if labels is not None: + lm_loss = self.loss_func(logits, labels) + return lm_loss + else: + return {"logits": logits} + + def set_cache(self, past_key_values): + self.past_length = 0 if past_key_values is None else past_key_values[0][0].shape[2] + + if past_key_values is None: + past_key_values = [None] * self.cfg.hidden_layers + + assert len(past_key_values) == self.cfg.hidden_layers, ( + f"past_key_values's length {len(past_key_values)} doesn't match " + f"num_layers:' {self.cfg.hidden_layers}" + ) + + def prepare_inputs_for_generation(self, input_ids: flow.Tensor, **kwargs): + if "attention_mask" in kwargs: + attention_mask = kwargs.pop("attention_mask").float() + attention_mask = attention_mask - 1 + attention_mask.masked_fill_(attention_mask == -1, flow.finfo(flow.float32).min) + return {"input_ids": input_ids, "attention_mask": attention_mask} + + @classmethod + def from_config(cls, cfg): + return { + "hidden_layers": cfg.hidden_layers, + "vocab_size": cfg.vocab_size, + "hidden_size": cfg.hidden_size, + "intermediate_size": cfg.intermediate_size, + "num_attention_heads": cfg.num_attention_heads, + "max_position_embeddings": cfg.max_position_embeddings, + "rms_norm_eps": cfg.rms_norm_eps, + "initializer_range": cfg.initializer_range, + "use_scaled_init_for_output_weights": cfg.use_scaled_init_for_output_weights, + "scale_mask_softmax_fusion": cfg.scale_mask_softmax_fusion, + "amp_enabled": cfg.amp_enabled, + "cfg": cfg, + } + + @staticmethod + def set_activation_checkpoint(model): + for module_block in model.modules(): + # Old API in OneFlow 0.8 + if hasattr(module_block, "origin"): + if isinstance(module_block.origin, LlamaDecoderLayer): + module_block.config.activation_checkpointing = True + else: + if isinstance(module_block.to(nn.Module), LlamaDecoderLayer): + module_block.to(nn.graph.GraphModule).activation_checkpointing = True diff --git a/projects/Llama3/adapter/adapter_sft.py b/projects/Llama3/adapter/adapter_sft.py new file mode 100644 index 000000000..d8825e2bb --- /dev/null +++ b/projects/Llama3/adapter/adapter_sft.py @@ -0,0 +1,97 @@ +import os + +from omegaconf import OmegaConf + +from configs.common.models.graph import graph +from configs.common.optim import optim +from configs.common.train import train +from libai.config import LazyCall +from libai.data.build import build_nlp_test_loader, build_nlp_train_loader +from libai.evaluation import PPLEvaluator +from libai.scheduler import WarmupExponentialLR +from projects.Llama3.adapter.adapter_config import cfg +from projects.Llama3.adapter.adapter_model import 
LlamaForCausalLM +from projects.Llama3.dataset import AlpacaDataset +from projects.Llama3.tokenizer import LlamaTokenizer + +# Hyperparameters +weight_decay = 0.1 +learning_rate = 2e-5 +max_input_length = 512 +dataset_path = "alpaca_data" +pretrained_model_path = "meta-llama/Llama-2-7b-hf" + +# graph & optim +graph["enabled"] = False +optim.update( + dict( + lr=learning_rate, + weight_decay=weight_decay, + ) +) + +# tokenize +tokenization = OmegaConf.create() +tokenization.make_vocab_size_divisible_by = 1 +tokenization.tokenizer = LazyCall(LlamaTokenizer)( + pretrained_model_path=os.path.join(pretrained_model_path, "tokenizer.model") +) + +# model +cfg.use_cache = False +model = LazyCall(LlamaForCausalLM)(cfg=cfg) + +# datasets +dataloader = OmegaConf.create() +dataloader.train = LazyCall(build_nlp_train_loader)( + dataset=[ + LazyCall(AlpacaDataset)( + path=os.path.join(dataset_path, "train"), tokenizer=tokenization.tokenizer + ) + ], +) +dataloader.test = [ + LazyCall(build_nlp_test_loader)( + dataset=LazyCall(AlpacaDataset)( + path=os.path.join(dataset_path, "test"), tokenizer=tokenization.tokenizer + ), + ), +] + + +train.update( + dict( + output_dir="./sft_result", + train_micro_batch_size=8, + test_micro_batch_size=1, + train_epoch=3, + train_iter=1, + log_period=10, + warmup_ratio=2 / 5, + num_accumulation_steps=8, + rdma_enabled=False, + amp=dict(enabled=True), + activation_checkpoint=dict(enabled=True), + checkpointer=dict( + period=5000, + max_to_keep=20, + ), + dist=dict( + data_parallel_size=1, + tensor_parallel_size=1, + pipeline_parallel_size=8, + pipeline_num_layers=cfg.hidden_layers, + ), + evaluation=dict( + enabled=True, + evaluator=LazyCall(PPLEvaluator)(), + eval_period=1000, + eval_iter=100, + ), + scheduler=LazyCall(WarmupExponentialLR)( + warmup_factor=0.0, + gamma=1.0, + warmup_method="linear", + ), + ) +) diff --git a/projects/Llama3/adapter/train_net.py b/projects/Llama3/adapter/train_net.py new file mode 100644 index 000000000..3f60dfcce --- /dev/null +++ b/projects/Llama3/adapter/train_net.py @@ -0,0 +1,115 @@ +# coding=utf-8 +# Copyright 2021 The OneFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import random +import sys + +import numpy as np +import oneflow as flow + +import libai.utils.distributed as dist +from libai.config import LazyConfig, default_argument_parser, try_get_key +from libai.engine import DefaultTrainer, default_setup +from libai.utils.checkpoint import Checkpointer +from projects.Llama3.utils.llama_loader import LlamaLoaderHuggerFace + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) + +logger = logging.getLogger("libai." 
+ __name__) + + +def build_model(cfg): + model_loader = LlamaLoaderHuggerFace( + cfg, + cfg.cfg, + cfg.cfg.pretrained_model_path, + ) + model = model_loader.load() + + for name, param in model.named_parameters(): + if "adapter" not in name: + param.requires_grad = False + else: + param.requires_grad = True + param.data = param.data.float() + + for name, param in model.model.layers[-cfg.cfg.adapter_layer :].named_parameters(): + if "gate" in name or "adapter" in name: + param.data = param.data.float() + param.requires_grad = True + + return model + + +class LlamaTrainer(DefaultTrainer): + @classmethod + def build_model(cls, cfg): + assert try_get_key(cfg, "model") is not None, "cfg must contain `model` namespace" + # Set model fp16 option because of embedding layer `white_identity` manual + # insert for amp training if provided. + if try_get_key(cfg.model, "cfg.amp_enabled") is not None: + cfg.model.cfg.amp_enabled = cfg.train.amp.enabled and cfg.graph.enabled + # In case some model define without cfg keyword. + elif try_get_key(cfg.model, "amp_enabled") is not None: + cfg.model.amp_enabled = cfg.train.amp.enabled and cfg.graph.enabled + model = build_model(cfg.model) + logger = logging.getLogger(__name__) + logger.info("Model:\n{}".format(model)) + model._apply(dist.convert_to_distributed_default_setting) + return model + + +def main(args): + cfg = LazyConfig.load(args.config_file) + cfg = LazyConfig.apply_overrides(cfg, args.opts) + default_setup(cfg, args) + + seed_for_rank = cfg.train.seed + flow.env.get_rank() + flow.manual_seed(seed_for_rank) + flow.cuda.manual_seed(seed_for_rank) + np.random.seed(seed_for_rank) + random.seed(seed_for_rank) + + if args.fast_dev_run: + cfg.train.train_epoch = 0 + cfg.train.train_iter = 20 + cfg.train.evaluation.eval_period = 10 + cfg.train.log_period = 1 + + if args.eval_only: + tokenizer = None + if try_get_key(cfg, "tokenization") is not None: + tokenizer = DefaultTrainer.build_tokenizer(cfg) + model = DefaultTrainer.build_model(cfg) + Checkpointer(model, save_dir=cfg.train.output_dir).resume_or_load( + cfg.train.load_weight, resume=args.resume + ) + if try_get_key(cfg, "graph.enabled", default=False): + model = DefaultTrainer.build_graph(cfg, model, is_train=False) + test_loader = DefaultTrainer.build_test_loader(cfg, tokenizer) + if len(test_loader) == 0: + logger.info("No dataset in dataloader.test, please set dataset for dataloader.test") + _ = DefaultTrainer.test(cfg, test_loader, model) + return + + trainer = LlamaTrainer(cfg) + return trainer.train() + + +if __name__ == "__main__": + args = default_argument_parser().parse_args() + main(args) diff --git a/projects/Llama3/configs/llama_config.py b/projects/Llama3/configs/llama_config.py new file mode 100644 index 000000000..4724e5e26 --- /dev/null +++ b/projects/Llama3/configs/llama_config.py @@ -0,0 +1,61 @@ +from omegaconf import DictConfig, OmegaConf + +from libai.config import LazyCall +from projects.Llama3.llama import LlamaForCausalLM +from projects.Llama3.tokenizer import LlamaTokenizer +from configs.common.train import train + + +cfg = dict( + # Model + hidden_act="silu", + hidden_size=4096, + initializer_range=0.02, + intermediate_size=14336, + max_position_embeddings=8192, + num_attention_heads=32, + hidden_layers=32, + pretraining_tp=1, + rms_norm_eps=1e-05, + rope_scaling=None, + tie_word_embeddings=False, + vocab_size=128256, + use_scaled_init_for_output_weights=False, + scale_mask_softmax_fusion=False, + amp_enabled=True, + # Inference + is_encoder_decoder=False, + max_length=256, 
+ min_length=0, + do_sample=False, + early_stopping=False, + num_beams=1, + num_beam_groups=1, + diversity_penalty=0.0, + temperature=0.9, + top_k=50, + top_p=0.6, + typical_p=1.0, + repetition_penalty=1.0, + length_penalty=1.0, + no_repeat_ngram_size=0, + encoder_no_repeat_ngram_size=0, + num_return_sequences=1, + chunk_size_feed_forward=0, + output_scores=False, + use_cache=True, + bos_token_id=1, + eos_token_id=2, + pad_token_id=0, + # train + pretrained_model_path="../../data/hf_models/meta-llama/Meta-Llama-3-8B", +) + +cfg = DictConfig(cfg) + +model = LazyCall(LlamaForCausalLM)(cfg=cfg) +tokenization = OmegaConf.create() +tokenization.make_vocab_size_divisible_by = 1 +tokenization.tokenizer = LazyCall(LlamaTokenizer)( + pretrained_model_path="../../data/hf_models/meta-llama/Meta-Llama-3-8B/original/tokenizer.model" +) diff --git a/projects/Llama3/configs/llama_sft.py b/projects/Llama3/configs/llama_sft.py new file mode 100644 index 000000000..157b89879 --- /dev/null +++ b/projects/Llama3/configs/llama_sft.py @@ -0,0 +1,97 @@ +import os +from omegaconf import OmegaConf + +from libai.config import LazyCall +from libai.evaluation import PPLEvaluator +from libai.scheduler import WarmupExponentialLR +from libai.data.build import build_nlp_test_loader, build_nlp_train_loader + +from configs.common.train import train +from configs.common.models.graph import graph +from configs.common.optim import optim + +from projects.Llama3.configs.llama_config import cfg +from projects.Llama3.dataset import AlpacaDataset +from projects.Llama3.tokenizer import LlamaTokenizer +from projects.Llama3.llama import LlamaForCausalLM + + +# Hyperparameters +weight_decay = 0.1 +learning_rate = 5e-5 +dataset_path = "/data/home/wujian/work/libai/alpaca_data" +pretrained_model_path = "../../data/hf_models/meta-llama/Meta-Llama-3-8B" + +# graph & optim +graph["enabled"] = False +optim.update( + dict( + lr=learning_rate, + weight_decay=weight_decay, + ) +) + +# tokenize +tokenization = OmegaConf.create() +tokenization.make_vocab_size_divisible_by = 1 +tokenization.tokenizer = LazyCall(LlamaTokenizer)( + pretrained_model_path=os.path.join(pretrained_model_path + "/original", "tokenizer.model") +) + +# model +model = LazyCall(LlamaForCausalLM)(cfg=cfg) + +# datasets +dataloader = OmegaConf.create() +dataloader.train = LazyCall(build_nlp_train_loader)( + dataset=[ + LazyCall(AlpacaDataset)( + path=os.path.join(dataset_path, "train"), tokenizer=tokenization.tokenizer + ) + ], +) +dataloader.test = [ + LazyCall(build_nlp_test_loader)( + dataset=LazyCall(AlpacaDataset)( + path=os.path.join(dataset_path, "test"), tokenizer=tokenization.tokenizer + ), + ), +] + + +train.update( + dict( + output_dir="./sft_result", + train_micro_batch_size=4, + test_micro_batch_size=1, + train_epoch=3, + train_iter=1, + log_period=10, + warmup_ratio=1 / 3, + num_accumulation_steps=8, + rdma_enabled=False, + amp=dict(enabled=True), + activation_checkpoint=dict(enabled=True), + checkpointer=dict( + period=5000, + max_to_keep=20, + ), + dist=dict( + data_parallel_size=1, + tensor_parallel_size=1, + pipeline_parallel_size=8, + pipeline_num_layers=cfg.hidden_layers, + ), + evaluation=dict( + enabled=True, + evaluator=LazyCall(PPLEvaluator)(), + eval_period=1000, + eval_iter=1e5, + ), + scheduler=LazyCall(WarmupExponentialLR)( + warmup_factor=0.0, + gamma=1.0, + warmup_method="linear", + ), + ) +) diff --git a/projects/Llama3/dataset.py b/projects/Llama3/dataset.py new file mode 100644 index 000000000..d78efe9fe --- /dev/null +++ 
b/projects/Llama3/dataset.py @@ -0,0 +1,19 @@ +import oneflow as flow +from oneflow.utils.data import Dataset + +from libai.data.structures import DistTensorData, Instance + + +class AlpacaDataset(Dataset): + def __init__(self, path, tokenizer): + self.data = flow.load(path) + self.tokenizer = tokenizer + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return Instance( + input_ids=DistTensorData(self.data[index]["input_ids"]), + labels=DistTensorData(self.data[index]["labels"]), + ) diff --git a/projects/Llama3/llama.py b/projects/Llama3/llama.py new file mode 100644 index 000000000..ea1b73541 --- /dev/null +++ b/projects/Llama3/llama.py @@ -0,0 +1,647 @@ +# coding=utf-8 +# Copyright 2021 The OneFlow Authors. All rights reserved. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Tuple + +import oneflow as flow +from oneflow import nn + +from libai.config import configurable +from libai.inference.generator.generation_utils import Generator +from libai.layers import Linear, RMSLayerNorm, VocabEmbedding +from libai.layers.attention import AttnMaskType +from libai.models.utils import init_method_normal, scaled_init_method_normal +from libai.utils import distributed as dist + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return flow.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + + def forward(self, x, seq_len=None, cos_cached=None, sin_cached=None): + if seq_len > self.max_position_embeddings: + raise ValueError( + f"The maximum supported length is {self.max_position_embeddings}, " + f"and the current length is{seq_len}." 
+ ) + + return ( + cos_cached[:seq_len].to_global(placement=x.placement), + sin_cached[:seq_len].to_global(placement=x.placement), + ) + + +class MLP(nn.Module): + def __init__( + self, + hidden_size, + intermediate_size, + init_method=nn.init.xavier_normal_, + output_layer_init_method=None, + *, + layer_idx=0, + ): + super().__init__() + + if output_layer_init_method is None: + output_layer_init_method = init_method + + self.gate_proj = Linear( + hidden_size, + intermediate_size, + bias=False, + parallel="col", + init_method=init_method, + layer_idx=layer_idx, + ) + + self.up_proj = Linear( + hidden_size, + intermediate_size, + bias=False, + parallel="col", + init_method=init_method, + layer_idx=layer_idx, + ) + + self.down_proj = Linear( + intermediate_size, + hidden_size, + bias=False, + parallel="row", + init_method=output_layer_init_method, + layer_idx=layer_idx, + ) + + self.activation_func = nn.SiLU() + + def forward(self, hidden_states): + gate_out = self.activation_func(self.gate_proj(hidden_states)) + up_out = self.up_proj(hidden_states) + output = self.down_proj(gate_out * up_out) + return output + + +class MultiheadAttention(nn.Module): + def __init__( + self, + hidden_size, + num_attention_heads, + max_position_embeddings, + init_method=nn.init.xavier_normal_, + output_layer_init_method=None, + scale_mask_softmax_fusion=False, + attn_mask_type=AttnMaskType.padding, + *, + layer_idx=0, + ): + super().__init__() + self.hidden_size = hidden_size + if output_layer_init_method is None: + output_layer_init_method = init_method + + self.num_heads = num_attention_heads + self.head_size = hidden_size // num_attention_heads + self.attn_mask_type = attn_mask_type + + self.norm_factor = 1.0 / math.sqrt(float(self.head_size)) + + self.scale_mask_softmax_fusion = scale_mask_softmax_fusion + + self.query_key_value = Linear( + self.hidden_size, + self.hidden_size * 3, + bias=False, + parallel="col", + init_method=init_method, + layer_idx=layer_idx, + ) + + self.o_proj = Linear( + self.hidden_size, + self.hidden_size, + bias=False, + parallel="row", + init_method=output_layer_init_method, + layer_idx=layer_idx, + ) + + self.coeff = None + + rotary_dim = self.head_size + self.rotary_embed = RotaryEmbedding( + dim=rotary_dim, + max_position_embeddings=max_position_embeddings, + ) + + def forward( + self, + hidden_states: flow.Tensor, + encoder_states: flow.Tensor = None, + attention_mask: flow.Tensor = None, + position_ids=None, + past_key_value: Tuple[flow.Tensor, flow.Tensor] = None, + cos_cached: flow.Tensor = None, + sin_cached: flow.Tensor = None, + use_cache: bool = False, + ): + if encoder_states is not None: + encoder_states = encoder_states.to_global(placement=hidden_states.placement) + + if attention_mask is not None: + attention_mask = attention_mask.to_global(placement=hidden_states.placement) + + bsz, tgt_len = hidden_states.size()[:2] + + query_key_value = self.query_key_value(hidden_states) + query_key_value = query_key_value.view(bsz, -1, self.num_heads, 3 * self.head_size) + query_key_value = query_key_value.permute( + 0, 2, 1, 3 + ) # [bsz, num_heads, src_len, 3 * head_size] + query, key, value = flow.chunk(query_key_value, chunks=3, dim=-1) + + kv_seq_len = key.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_embed( + value, seq_len=kv_seq_len, cos_cached=cos_cached, sin_cached=sin_cached + ) + query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids) + + if past_key_value is not None: + past_key, 
past_value = past_key_value + key = flow.cat((past_key.type_as(key), key), dim=2) + value = flow.cat((past_value.type_as(value), value), dim=2) + + # query, key, value: [S(0), S(1)], shape: [bsz, num_heads, seq_length, head_size] + if use_cache: + past_key_value = (key, value) + + # [bsz, num_heads, tgt_len, src_len] with [S(0), S(1)] + attention_scores = flow.matmul(query, key, transpose_b=True, alpha=self.norm_factor) + attention_weights = attention_scores + attention_mask + + attention_weights = flow.softmax(attention_weights, dim=-1) + # Context shape: [bsz, num_heads, tgt_len, head_size] with [S(0), S(1)] + context = flow.matmul(attention_weights, value) + + # Change shape: [bsz, num_heads, tgt_len, head_size] -> [bsz, tgt_len, num_heads, head_size] + context = context.transpose(1, 2) + output = self.o_proj(context.flatten(2)) + + if use_cache: + output = (output, past_key_value) + + return output + + +class CasualMask(nn.Module): + def __init__(self, max_positions=1024, dtype=flow.float16, *, layer_idx=0): + super().__init__() + self.dtype = dtype + self.mask = flow.full( + (max_positions, max_positions), + flow.finfo(dtype).min, + placement=dist.get_layer_placement(layer_idx), + sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), + ) + mask_cond = flow.arange( + self.mask.size(-1), + placement=dist.get_layer_placement(layer_idx), + sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), + ) + self.mask.masked_fill_(mask_cond < (mask_cond + 1).view(self.mask.size(-1), 1), 0) + self.mask = self.mask.to(dtype) + + def forward(self, input_ids, past_length=0, attention_mask=None, input_dtype=None): + bsz, tgt_len = input_ids.size() + casual_mask = self.mask[:tgt_len, :tgt_len] + if past_length > 0: + # in case past_key_values are used, we need to add a prefix ones mask to casual mask + casual_mask = flow.cat( + [flow.ones(tgt_len, past_length, dtype=self.dtype), casual_mask], dim=-1 + ) + casual_mask = ( + casual_mask.unsqueeze(0).unsqueeze(1).expand(bsz, 1, tgt_len, tgt_len + past_length) + ) + casual_mask = casual_mask.to_global(sbp=input_ids.sbp) + if attention_mask is not None: + bsz, src_len = attention_mask.size() + attention_mask = ( + attention_mask[:, None, None, :] + .expand(bsz, 1, tgt_len, src_len) + .to(casual_mask.dtype) + ) + attention_mask = attention_mask.to_global(placement=casual_mask.placement) + casual_mask = casual_mask + attention_mask + if input_dtype is not None: + casual_mask = casual_mask.to(input_dtype) + return casual_mask + + +class LlamaDecoderLayer(nn.Module): + def __init__( + self, + hidden_size, + intermediate_size, + num_attention_heads, + is_decoder=False, + rms_norm_eps=1e-5, + max_position_embeddings=None, + init_method=nn.init.xavier_normal_, + output_layer_init_method=None, + scale_mask_softmax_fusion=False, + attn_mask_type=AttnMaskType.padding, + *, + layer_idx=0, + ): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.rms_norm_eps = rms_norm_eps + self.max_position_embeddings = max_position_embeddings + self.attn_mask_type = attn_mask_type + + self.layer_idx = layer_idx + self.is_decoder = is_decoder + + self.scale_mask_softmax_fusion = scale_mask_softmax_fusion + + self.init_method = init_method + if output_layer_init_method is None: + output_layer_init_method = init_method + self.output_layer_init_method = output_layer_init_method + + self.input_layernorm = RMSLayerNorm( + self.hidden_size, eps=self.rms_norm_eps, 
layer_idx=self.layer_idx + ) + + self.self_attn = self.build_attention() + self.post_attention_layernorm = RMSLayerNorm( + self.hidden_size, eps=self.rms_norm_eps, layer_idx=self.layer_idx + ) + + self.mlp = MLP( + self.hidden_size, + self.intermediate_size, + self.init_method, + output_layer_init_method=self.output_layer_init_method, + layer_idx=self.layer_idx, + ) + + def forward( + self, + hidden_states, + attention_mask=None, + past_key_value=None, + cos_cached=None, + sin_cached=None, + use_cache=False, + ): + hidden_states = hidden_states.to_global(placement=dist.get_layer_placement(self.layer_idx)) + + # hidden_states shape: (batch_size, seq_length, hidden_size) + if attention_mask is not None: + attention_mask = attention_mask.to_global( + placement=dist.get_layer_placement(self.layer_idx) + ) + + if past_key_value is not None: + if self.is_decoder: + assert len(past_key_value) == 4 + self_attn_past_key_value = past_key_value[:2] + else: + self_attn_past_key_value = past_key_value + else: + self_attn_past_key_value = None + + layernorm_output = self.input_layernorm(hidden_states) + attention_output = self.self_attn( + layernorm_output, + attention_mask=attention_mask, + past_key_value=self_attn_past_key_value, + cos_cached=cos_cached, + sin_cached=sin_cached, + use_cache=use_cache, + ) + + if use_cache: + attention_output, presents = attention_output + + hidden_states = hidden_states + attention_output + + layernorm_output = self.post_attention_layernorm(hidden_states) + + mlp_output = self.mlp(layernorm_output) + + output = hidden_states + mlp_output + + if use_cache: + output = (output, presents) + return output + + def build_attention(self): + return MultiheadAttention( + self.hidden_size, + self.num_attention_heads, + self.max_position_embeddings, + init_method=self.init_method, + output_layer_init_method=self.output_layer_init_method, + scale_mask_softmax_fusion=self.scale_mask_softmax_fusion, + attn_mask_type=self.attn_mask_type, + layer_idx=self.layer_idx, + ) + + +class LlamaModel(nn.Module): + def __init__( + self, + hidden_layers, + vocab_size, + hidden_size, + intermediate_size, + num_attention_heads, + max_position_embeddings=1024, + rms_norm_eps=1e-5, + initializer_range=0.02, + use_scaled_init_for_output_weights=True, + scale_mask_softmax_fusion=False, + amp_enabled=False, + ): + super().__init__() + init_method = init_method_normal(sigma=initializer_range) + if use_scaled_init_for_output_weights: + output_layer_init_method = scaled_init_method_normal(initializer_range, hidden_layers) + else: + output_layer_init_method = init_method + + self.embed_tokens = VocabEmbedding( + vocab_size, hidden_size, init_method=init_method, amp_enabled=amp_enabled + ) + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer( + hidden_size, + intermediate_size, + num_attention_heads, + rms_norm_eps=rms_norm_eps, + max_position_embeddings=max_position_embeddings, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + scale_mask_softmax_fusion=scale_mask_softmax_fusion, + attn_mask_type=AttnMaskType.causal, + layer_idx=i, + ) + for i in range(hidden_layers) + ] + ) + self.norm = RMSLayerNorm(hidden_size, eps=rms_norm_eps, layer_idx=-1) + + self._set_cos_sin_cache( + rotary_dim=hidden_size // num_attention_heads, + seq_len=max_position_embeddings, + dtype=flow.float32, + layer_idx=0, + ) + + def _set_cos_sin_cache(self, rotary_dim, seq_len, base=10000, dtype=None, layer_idx=0): + position = flow.arange( + 0, + rotary_dim, + 2, + dtype=dtype, + 
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), + placement=dist.get_layer_placement(layer_idx), + ) + inv_freq = 1.0 / (base ** (position / rotary_dim)) + + t = flow.arange( + seq_len, + dtype=inv_freq.dtype, + sbp=inv_freq.sbp, + placement=inv_freq.placement, + ) + + freqs = flow.einsum("i,j->ij", t, inv_freq) + emb = flow.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype)) + self.register_buffer("sin_cached", emb.sin().to(dtype)) + + def forward( + self, + input_ids, + attention_mask=None, + past_key_values=None, + use_cache=False, + set_cache=None, + ): + if use_cache: + presents = [] + input_ids = input_ids.to_global(placement=dist.get_layer_placement(0)) + hidden_states = self.embed_tokens(input_ids) + + for layer, past_key_value in zip(self.layers, past_key_values): + hidden_states = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + cos_cached=self.cos_cached, + sin_cached=self.sin_cached, + use_cache=False, + ) + if use_cache: + hidden_states, present = hidden_states + presents.append(present) + + hidden_states = self.norm(hidden_states) + + if use_cache: + set_cache(presents) + + return hidden_states + + +class CrossEntropyLoss(nn.Module): + def forward(self, logits: flow.Tensor, target: flow.Tensor): + assert logits.ndim == 3 + assert target.ndim == 2 + assert logits.shape[0:2] == target.shape + + target = target.to_global(placement=logits.placement) + target = target * (target >= 0) + + lm_loss = flow._C.cross_entropy( + logits.view(-1, logits.shape[-1]), target.view(-1), ignore_index=0 + ) + return lm_loss + + +class SFTLoss(nn.Module): + def __init__(self) -> None: + super().__init__() + self.lm_loss = CrossEntropyLoss() + + def forward(self, logits, lm_labels): + lm_loss = self.lm_loss(logits, lm_labels) + lm_loss = lm_loss.mean() + return {"lm_loss": lm_loss} + + +class LlamaForCausalLM(nn.Module, Generator): + @configurable + def __init__( + self, + hidden_layers, + vocab_size, + hidden_size, + intermediate_size, + num_attention_heads, + max_position_embeddings=1024, + rms_norm_eps=1e-5, + initializer_range=0.02, + use_scaled_init_for_output_weights=True, + scale_mask_softmax_fusion=False, + amp_enabled=False, + cfg=None, + ): + super().__init__() + self.cfg = cfg + self.model = LlamaModel( + hidden_layers=hidden_layers, + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_attention_heads=num_attention_heads, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=rms_norm_eps, + initializer_range=initializer_range, + use_scaled_init_for_output_weights=use_scaled_init_for_output_weights, + scale_mask_softmax_fusion=scale_mask_softmax_fusion, + amp_enabled=amp_enabled, + ) + self.casual_mask = CasualMask(max_position_embeddings, layer_idx=0) + self.lm_head = Linear(hidden_size, vocab_size, bias=False, layer_idx=-1) + self.loss_func = SFTLoss() + + self.past_key_values = [None] * hidden_layers + self.past_length = 0 + + def forward(self, input_ids, attention_mask=None, labels=None, use_cache=False): + input_ids = input_ids.to_global(placement=dist.get_layer_placement(0)) + attention_mask = ( + attention_mask.to_global(placement=dist.get_layer_placement(0)) + if attention_mask is not None + else attention_mask + ) + labels = ( + labels.to_global(placement=dist.get_layer_placement(0)) + if labels is not None + else labels + ) + + if use_cache and self.past_key_values[0] is not None: + self.past_length = 
self.past_key_values[0][0].size(-2) + else: + self.past_length = 0 + + mask = self.casual_mask( + input_ids, + past_length=self.past_length, + attention_mask=attention_mask, + input_dtype=self.lm_head.weight.dtype, + ) + + output = self.model( + input_ids, + attention_mask=mask, + past_key_values=self.past_key_values, + use_cache=use_cache, + set_cache=self.set_cache, + ) + + logits = self.lm_head(output) + + if labels is not None: + lm_loss = self.loss_func(logits, labels) + return lm_loss + else: + return {"logits": logits} + + def set_cache(self, past_key_values): + self.past_length = 0 if past_key_values is None else past_key_values[0][0].shape[2] + + if past_key_values is None: + past_key_values = [None] * self.cfg.hidden_layers + + assert len(past_key_values) == self.cfg.hidden_layers, ( + f"past_key_values's length {len(past_key_values)} doesn't match " + f"num_layers:' {self.cfg.hidden_layers}" + ) + + def prepare_inputs_for_generation(self, input_ids: flow.Tensor, **kwargs): + if "attention_mask" in kwargs: + attention_mask = kwargs.pop("attention_mask").float() + attention_mask = attention_mask - 1 + attention_mask.masked_fill_(attention_mask == -1, flow.finfo(flow.float32).min) + return {"input_ids": input_ids, "attention_mask": attention_mask} + + @classmethod + def from_config(cls, cfg): + return { + "hidden_layers": cfg.hidden_layers, + "vocab_size": cfg.vocab_size, + "hidden_size": cfg.hidden_size, + "intermediate_size": cfg.intermediate_size, + "num_attention_heads": cfg.num_attention_heads, + "max_position_embeddings": cfg.max_position_embeddings, + "rms_norm_eps": cfg.rms_norm_eps, + "initializer_range": cfg.initializer_range, + "use_scaled_init_for_output_weights": cfg.use_scaled_init_for_output_weights, + "scale_mask_softmax_fusion": cfg.scale_mask_softmax_fusion, + "amp_enabled": cfg.amp_enabled, + "cfg": cfg, + } + + @staticmethod + def set_activation_checkpoint(model): + for module_block in model.modules(): + # Old API in OneFlow 0.8 + if hasattr(module_block, "origin"): + if isinstance(module_block.origin, LlamaDecoderLayer): + module_block.config.activation_checkpointing = True + else: + if isinstance(module_block.to(nn.Module), LlamaDecoderLayer): + module_block.to(nn.graph.GraphModule).activation_checkpointing = True diff --git a/projects/Llama3/pipeline.py b/projects/Llama3/pipeline.py new file mode 100644 index 000000000..4b65d2895 --- /dev/null +++ b/projects/Llama3/pipeline.py @@ -0,0 +1,128 @@ +# coding=utf-8 +# Copyright 2021 The OneFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import click + +from libai.inference.basic import BasePipeline +from libai.utils import distributed as dist + + +class TextGenerationPipeline(BasePipeline): + def load_pretrain_weight(self, libai_cfg_model, model_path, mode="huggingface"): + """load pretrained model. 
+
+        Args:
+            libai_cfg_model (libai.models): Lazy config Model in Libai, you can import it
+                by `from libai.config.configs.common.models.bert
+                    import pretrain_model as libai_cfg_model`
+            model_path (str): The directory path of the pretrained model.
+        """
+        if mode == "huggingface":
+            from projects.Llama3.utils.llama_loader import LlamaLoaderHuggerFace
+
+            model_loader = LlamaLoaderHuggerFace(
+                libai_cfg_model,
+                libai_cfg_model.cfg,
+                model_path,
+            )
+            model = model_loader.load()
+            model.eval()
+            return model
+
+        elif mode == "libai":
+            from projects.Llama3.utils.llama_loader import LlamaLoaderLiBai
+
+            model_loader = LlamaLoaderLiBai(
+                libai_cfg_model,
+                libai_cfg_model.cfg,
+                model_path,
+            )
+            model = model_loader.load()
+            model.eval()
+            return model
+
+        elif mode == "random":
+            from libai.engine import DefaultTrainer
+
+            return DefaultTrainer.build_model(self.cfg)
+        else:
+            raise NotImplementedError
+
+    def _parse_parameters(self, **pipeline_parameters):
+        preprocess_params = {}
+        forward_params = {**pipeline_parameters}
+        postprocess_params = {}
+
+        return preprocess_params, forward_params, postprocess_params
+
+    def preprocess(self, inputs, **kwargs) -> dict:
+        # tokenizer encode
+        inputs = self.tokenizer.tokenize(inputs, add_bos=True, padding=True, device=self.device)
+        inputs = {
+            "input_ids": inputs,
+        }
+
+        return inputs
+
+    def forward(self, inputs, **kwargs) -> dict:
+        outputs = self.model.generate(inputs["input_ids"], max_length=50, **kwargs)
+        return {"return_ids": outputs}
+
+    def postprocess(self, model_output_dict, **kwargs) -> dict:
+        return_ids = model_output_dict["return_ids"]
+        records = [
+            {"generated_text": self.tokenizer.decode(return_ids[i])}
+            for i in range(return_ids.size(0))
+        ]
+        return records
+
+
+@click.command()
+@click.option(
+    "--config_file",
+    default="projects/Llama3/configs/llama_config.py",
+    help="Path to the configuration file.",
+)
+@click.option("--model_path", default=None, help="Path to the model checkpoint.")
+@click.option(
+    "--mode",
+    default="libai",
+    help="Mode for loading the pretrained weights: 'libai', 'huggingface', or 'random'.",
+)
+@click.option(
+    "--device", default="cuda", help="Device to run the model on, e.g., 'cuda', 'xpu', 'npu'."
+)
+def main(config_file, model_path, mode, device):
+    pipeline = TextGenerationPipeline(
+        config_file,
+        data_parallel=1,
+        tensor_parallel=1,
+        pipeline_parallel=1,
+        pipeline_num_layers=32,
+        model_path=model_path,
+        mode=mode,
+        device=device,
+    )
+
+    text = [
+        "Give three tips for staying healthy.",
+    ]
+    output = pipeline(inputs=text)
+    if dist.is_main_process():
+        print(output)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/projects/Llama3/tokenizer.py b/projects/Llama3/tokenizer.py
new file mode 100644
index 000000000..a8b77267b
--- /dev/null
+++ b/projects/Llama3/tokenizer.py
@@ -0,0 +1,112 @@
+# coding=utf-8
+# Copyright 2021 The OneFlow Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import oneflow as flow
+# import sentencepiece as spm
+import tiktoken
+from pathlib import Path
+from tiktoken.load import load_tiktoken_bpe
+
+import libai.utils.distributed as dist
+
+
+class LlamaTokenizer:
+    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501
+    def __init__(
+        self,
+        pretrained_model_path,
+        bos_token="<|begin_of_text|>",
+        eos_token="<|end_of_text|>",
+        pad_token="",
+        bos_token_id=None,
+        eos_token_id=None,
+    ):
+        mergeable_ranks = load_tiktoken_bpe(pretrained_model_path)
+        num_base_tokens = len(mergeable_ranks)
+        special_tokens = [
+            "<|begin_of_text|>",
+            "<|end_of_text|>",
+        ]
+        self.special_tokens = {
+            token: num_base_tokens + i for i, token in enumerate(special_tokens)
+        }
+        self.tik_model = tiktoken.Encoding(
+            name=Path(pretrained_model_path).name,
+            pat_str=self.pat_str,
+            mergeable_ranks=mergeable_ranks,
+            special_tokens=self.special_tokens,
+        )
+
+        self.bos_token = bos_token
+        self.eos_token = eos_token
+        self.pad_token = pad_token
+        self.bos_token_id = self.special_tokens["<|begin_of_text|>"]
+        self.eos_token_id = self.special_tokens["<|end_of_text|>"]
+        self.pad_token_id = 0
+
+    @property
+    def vocab_size(self):
+        return self.tik_model.n_vocab
+
+    def get_vocab(self):
+        vocab = {self.convert_id_to_token(i): i for i in range(self.vocab_size)}
+        return vocab
+
+    def encode(self, text):
+        tokens = self.tik_model.encode(text)
+        return tokens
+
+    def tokenize(
+        self,
+        text,
+        add_bos=False,
+        add_eos=False,
+        padding=False,
+        device="cuda",
+        max_length=4096,
+        **kwargs
+    ):
+        if isinstance(text, str):
+            tokens = [self.tik_model.encode(text)[:max_length]]
+
+        if isinstance(text, list):
+            tokens = [self.tik_model.encode(s)[:max_length] for s in text]
+            if padding:
+                max_length = max([len(i) for i in tokens])
+                tokens = [t + (max_length - len(t)) * [self.pad_token_id] for t in tokens]
+
+        if add_bos:
+            tokens = [[self.bos_token_id] + token for token in tokens]
+        if add_eos:
+            tokens = [token + [self.eos_token_id] for token in tokens]
+
+        if device:
+            sbp = kwargs.get("sbp", dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]))
+            placement = kwargs.get("placement", flow.placement(device, [0]))
+            return_token_ids = flow.tensor(tokens, sbp=sbp, placement=placement, dtype=flow.long)
+        else:
+            return_token_ids = flow.tensor(tokens, dtype=flow.long)
+        return return_token_ids
+
+    def decode(self, tokens):
+        if isinstance(tokens, flow.Tensor):
+            tokens = tokens.tolist()
+        return self.tik_model.decode(tokens)
+
+    def convert_token_to_id(self, token):
+        return self.tik_model.encode_single_token(token)
+
+    def convert_id_to_token(self, index):
+        return self.tik_model.decode_single_token_bytes(index)
diff --git a/projects/Llama3/train_net.py b/projects/Llama3/train_net.py
new file mode 100644
index 000000000..3f730b272
--- /dev/null
+++ b/projects/Llama3/train_net.py
@@ -0,0 +1,102 @@
+# coding=utf-8
+# Copyright 2021 The OneFlow Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import logging +import os +import random +import sys + +import numpy as np +import oneflow as flow + +import libai.utils.distributed as dist +from libai.config import LazyConfig, default_argument_parser, try_get_key +from libai.engine import DefaultTrainer, default_setup +from libai.utils.checkpoint import Checkpointer +from projects.Llama3.utils.llama_loader import LlamaLoaderHuggerFace + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) + +logger = logging.getLogger("libai." + __name__) + + +def build_model(cfg): + model_loader = LlamaLoaderHuggerFace( + cfg, + cfg.cfg, + cfg.cfg.pretrained_model_path, + ) + model = model_loader.load() + return model + + +class LlamaTrainer(DefaultTrainer): + @classmethod + def build_model(cls, cfg): + assert try_get_key(cfg, "model") is not None, "cfg must contain `model` namespace" + # Set model fp16 option because of embedding layer `white_identity` manual + # insert for amp training if provided. + if try_get_key(cfg.model, "cfg.amp_enabled") is not None: + cfg.model.cfg.amp_enabled = cfg.train.amp.enabled and cfg.graph.enabled + # In case some model define without cfg keyword. + elif try_get_key(cfg.model, "amp_enabled") is not None: + cfg.model.amp_enabled = cfg.train.amp.enabled and cfg.graph.enabled + model = build_model(cfg.model) + logger = logging.getLogger(__name__) + logger.info("Model:\n{}".format(model)) + model._apply(dist.convert_to_distributed_default_setting) + return model + + +def main(args): + cfg = LazyConfig.load(args.config_file) + cfg = LazyConfig.apply_overrides(cfg, args.opts) + default_setup(cfg, args) + + seed_for_rank = cfg.train.seed + flow.env.get_rank() + flow.manual_seed(seed_for_rank) + flow.cuda.manual_seed(seed_for_rank) + np.random.seed(seed_for_rank) + random.seed(seed_for_rank) + + if args.fast_dev_run: + cfg.train.train_epoch = 0 + cfg.train.train_iter = 20 + cfg.train.evaluation.eval_period = 10 + cfg.train.log_period = 1 + + if args.eval_only: + tokenizer = None + if try_get_key(cfg, "tokenization") is not None: + tokenizer = DefaultTrainer.build_tokenizer(cfg) + model = DefaultTrainer.build_model(cfg) + Checkpointer(model, save_dir=cfg.train.output_dir).resume_or_load( + cfg.train.load_weight, resume=args.resume + ) + if try_get_key(cfg, "graph.enabled", default=False): + model = DefaultTrainer.build_graph(cfg, model, is_train=False) + test_loader = DefaultTrainer.build_test_loader(cfg, tokenizer) + if len(test_loader) == 0: + logger.info("No dataset in dataloader.test, please set dataset for dataloader.test") + _ = DefaultTrainer.test(cfg, test_loader, model) + return + + trainer = LlamaTrainer(cfg) + return trainer.train() + + +if __name__ == "__main__": + args = default_argument_parser().parse_args() + main(args) diff --git a/projects/Llama3/utils/eval_adapter.py b/projects/Llama3/utils/eval_adapter.py new file mode 100644 index 000000000..f7ab0f053 --- /dev/null +++ b/projects/Llama3/utils/eval_adapter.py @@ -0,0 +1,177 @@ +import json +from pathlib import Path +from typing import Dict, List, Optional + +import oneflow as flow + +flow.mock_torch.enable(lazy=True) + +from lm_eval import evaluator, tasks # noqa +from lm_eval.base import BaseLM # noqa +from omegaconf import DictConfig # noqa + +import libai.utils.distributed as dist # noqa +from libai.config import instantiate # noqa +from projects.Llama3.configs.llama_config import cfg, tokenization # noqa +from projects.Llama3.llama import LlamaForCausalLM # noqa +from projects.Llama3.utils.llama_loader import 
LlamaLoaderHuggerFace, LlamaLoaderLiBai # noqa + + +class EvalHarnessBase(BaseLM): + def __init__(self, model, tokenizer, batch_size: int, cfg: dict): + super().__init__() + self.model = model + self.tokenizer = tokenizer + self.batch_size_per_gpu = batch_size + self.cfg = cfg + + @classmethod + def create_from_arg_string(cls, arg_string, additional_config=None): + pass + + @property + def eot_token_id(self): + return self.tokenizer.eos_token_id + + @property + def max_length(self): + return self.cfg.max_position_embeddings + + @property + def vocab_size(self): + return self.cfg.vocab_size + + @property + def max_gen_toks(self): + return self.cfg.get("max_length", 256) + + @property + def batch_size(self): + return self.batch_size_per_gpu * dist.get_world_size() + + @property + def device(self): + return flow.device("cuda:0") + + def tok_encode(self, string: str) -> List[int]: + return self.tokenizer.tokenize(string, add_bos=False, add_eos=False).squeeze(0).tolist() + + def tok_decode(self, tokens: List[int]) -> str: + return self.tokenizer.decode(tokens) + + @flow.inference_mode() + def _model_call(self, inps): + inps = inps.to_global( + sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), + placement=dist.get_layer_placement(0), + ) + return self.model(inps)["logits"].to_local().to(flow.float32) + + def _model_generate(self, context, max_length, eos_token_id) -> flow.Tensor: + # this only supports batch size 1 + assert context.shape[0] == 1 + out = self.model.generate(context[0], max_length, eos_id=eos_token_id) + return out.unsqueeze(0) + + @flow.inference_mode() + def run_eval( + self, + eval_tasks: List[str], + num_fewshot: int, + limit: Optional[int], + bootstrap_iters: int, + ) -> Dict: + import fnmatch + + def pattern_match(patterns, source_list): + task_names = set() + for pattern in patterns: + for matching in fnmatch.filter(source_list, pattern): + task_names.add(matching) + return list(task_names) + + eval_tasks = pattern_match(eval_tasks, tasks.ALL_TASKS) + print(f"Found tasks: {eval_tasks}") + + if dist.is_main_process() == 0: + tasks.get_task_dict(eval_tasks) + dist.synchronize() + tasks.get_task_dict(eval_tasks) + + lm = self + + results = evaluator.evaluate( + lm=lm, + task_dict=tasks.get_task_dict(eval_tasks), + num_fewshot=num_fewshot, + limit=limit, + bootstrap_iters=bootstrap_iters, + ) + results["config"] = dict( + model="llama", + batch_size=self.batch_size, + device=str(self.device), + num_fewshot=num_fewshot, + limit=limit, + bootstrap_iters=bootstrap_iters, + ) + return results + + +@flow.inference_mode() +def run_eval_harness( + model, + tokenizer, + eval_tasks: List[str] = [ + "hellaswag", + ], + save_filepath: Optional[Path] = None, + num_fewshot: int = 0, + limit: Optional[int] = None, + bootstrap_iters: int = 100000, + dtype=flow.float16, + cfg=None, +): + model.eval() + model = model.to(dtype) + eval_harness = EvalHarnessBase(model, tokenizer, 1, cfg) + + results = eval_harness.run_eval(eval_tasks, num_fewshot, limit, bootstrap_iters) + if save_filepath is None: + print(results) + else: + print(f"Saving results to {str(save_filepath)!r}") + data = json.dumps(results) + with open(save_filepath, "w") as fw: + fw.write(data) + + +if __name__ == "__main__": + parallel_config = DictConfig( + dict( + data_parallel_size=1, + tensor_parallel_size=8, + pipeline_parallel_size=1, + pipeline_num_layers=32, + device_type="cuda", + ) + ) + dist.setup_dist_util(parallel_config) + + tokenizer = instantiate(tokenization.tokenizer) + + # ----- load huggingface 
checkpoint ----- + # load_func = LlamaLoaderHuggerFace( + # model=LlamaForCausalLM, + # libai_cfg=cfg, + # pretrained_model_path="", + # ) + + # ----- load oneflow checkpoint ----- + load_func = LlamaLoaderLiBai( + model=LlamaForCausalLM, + libai_cfg=cfg, + pretrained_model_path="", + ) + model = load_func.load() + run_eval_harness(model, tokenizer, cfg=cfg) diff --git a/projects/Llama3/utils/llama_loader.py b/projects/Llama3/utils/llama_loader.py new file mode 100644 index 000000000..c46cb480a --- /dev/null +++ b/projects/Llama3/utils/llama_loader.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright 2021 The OneFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import oneflow as flow + +from libai.models.utils.model_loader.base_loader import ModelLoaderHuggerFace, ModelLoaderLiBai + + +class LlamaLoaderHuggerFace(ModelLoaderHuggerFace): + def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs): + super().__init__(model, libai_cfg, pretrained_model_path, **kwargs) + + self.base_model_prefix_1 = "model" + self.base_model_prefix_2 = "model" + if not pretrained_model_path: + self.pretrained_model_path = libai_cfg.pretrained_model_path + + def _convert_state_dict(self, flow_state_dict, cfg): + """Convert state_dict's keys to match model. + + Args: + flow_state_dict (OrderedDict): model state dict. + cfg (dict): model's default config dict in LiBai. + + Returns: + OrderedDict: flow state dict. + """ + # The converted checkpoint. + oneflow_state_dict = flow_state_dict.copy() + old_keys = list(oneflow_state_dict.keys()) + + # Get configs + num_attention_heads = cfg.get("num_attention_heads") + hidden_size = cfg.get("hidden_size") + head_size = int(hidden_size // num_attention_heads) + + new_key_qkv = "model.layers.{}.self_attn.query_key_value.weight" + old_key_qkv = "model.layers.{}.self_attn.{}.weight" + for layer_idx in range(cfg.get("hidden_layers")): + query = old_key_qkv.format(layer_idx, "q_proj") + key = old_key_qkv.format(layer_idx, "k_proj") + value = old_key_qkv.format(layer_idx, "v_proj") + q = oneflow_state_dict[query] + k = oneflow_state_dict[key] + v = oneflow_state_dict[value] + qkv = flow.cat([q, k, v], dim=0) + qkv = self._fix_qkv_ordering(qkv, head_size, num_attention_heads, hidden_size) + oneflow_state_dict[new_key_qkv.format(layer_idx)] = qkv + oneflow_state_dict.pop(query) + oneflow_state_dict.pop(key) + oneflow_state_dict.pop(value) + + for k in old_keys: + if "inv_freq" in k: + oneflow_state_dict.pop(k) + + return oneflow_state_dict + + def _load_config_from_json(self, config_file): + """load config from `config.json`, and update default config. + + Args: + config_file (str): Path of config file. 
+ """ + with open(config_file, mode="r", encoding="utf-8") as f: + cfg_dict = json.load(f) + + # update libai_cfg by config.json + self._update_cfg("hidden_layers", cfg_dict["num_hidden_layers"]) + self._update_cfg("hidden_size", cfg_dict["hidden_size"]) + self._update_cfg("num_attention_heads", cfg_dict["num_attention_heads"]) + self._update_cfg("max_position_embeddings", cfg_dict["max_position_embeddings"]) + self._update_cfg("intermediate_size", cfg_dict["intermediate_size"]) + self._update_cfg("rms_norm_eps", cfg_dict["rms_norm_eps"]) + self._update_cfg("vocab_size", cfg_dict["vocab_size"]) + self._update_cfg("initializer_range", cfg_dict["initializer_range"]) + self._update_cfg( + "ffn_hidden_size", + cfg_dict.get("n_inner") + if cfg_dict.get("n_inner") is not None + else 4 * self.libai_cfg["hidden_size"], + ) + + # update libai_cfg by kwargs + for k, v in self.kwargs.items(): + self._update_cfg(k, v) + + self._update_cfg_log() + + +class LlamaLoaderLiBai(ModelLoaderLiBai): + def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs): + super().__init__(model, libai_cfg, pretrained_model_path, **kwargs) + self.base_model_prefix_2 = "model" + if not pretrained_model_path: + self.pretrained_model_path = libai_cfg.pretrained_model_path diff --git a/projects/Llama3/utils/prepare_alpaca.py b/projects/Llama3/utils/prepare_alpaca.py new file mode 100644 index 000000000..2e5a045e5 --- /dev/null +++ b/projects/Llama3/utils/prepare_alpaca.py @@ -0,0 +1,162 @@ +"""Implementation derived from https://github.com/tloen/alpaca-lora""" +import copy +import json +import math +import os +from pathlib import Path +from typing import Optional + +import oneflow as flow +import requests +from oneflow.utils.data import random_split +from tqdm import tqdm + +from libai.config import instantiate +from libai.utils.logger import setup_logger +from projects.Llama3.configs.llama_config import tokenization + +logger = setup_logger() + + +def prepare( + destination_path: Path = Path("alpaca_data"), + checkpoint_dir: Path = Path("~/data/hf_models/meta-llama/Meta-Llama-3-8B"), + test_split_fraction: float = 0.03865, # to get exactly 2000 test samples, + seed: int = 42, + mask_inputs: bool = False, # as in alpaca-lora + data_file_name: str = "alpaca_data_cleaned_archive.json", + data_file_url: str = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json", # noqa + ignore_index: int = -1, + max_seq_length: Optional[int] = 512, +) -> None: + """Prepare the Alpaca dataset for instruction tuning. + The output is a training and test dataset saved as `train.pt` and `test.pt`, + which stores the preprocessed and tokenized prompts and labels. 
+ """ + if max_seq_length is None: + with open(os.path.join(checkpoint_dir, "config.json"), "r", encoding="utf-8") as file: + config = json.load(file) + max_seq_length = config["max_position_embeddings"] + + destination_path.mkdir(parents=True, exist_ok=True) + data_file_path = destination_path / data_file_name + logger.info("Loading data file...") + download_if_missing(data_file_path, data_file_url) + with open(data_file_path, "r", encoding="utf-8") as file: + data = json.load(file) + + logger.info("Loading tokenizer...") + tokenizer = instantiate(tokenization.tokenizer) + + # Partition the dataset into train and test + num_of_test_samples = math.floor(test_split_fraction * len(data)) + num_of_train_samples = len(data) - num_of_test_samples + train_set, test_set = random_split( + data, + [num_of_train_samples, num_of_test_samples], + generator=flow.Generator().manual_seed(seed), + ) + train_set, test_set = list(train_set), list(test_set) + + logger.info(f"train has {len(train_set):,} samples") + logger.info(f"test has {len(test_set):,} samples") + + logger.info("Processing train split ...") + train_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + ) + for sample in tqdm(train_set) + ] + flow.save(train_set, destination_path / "train") + + logger.info("Processing test split ...") + test_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + ) + for sample in tqdm(test_set) + ] + flow.save(test_set, destination_path / "test") + + max_length = max([i["input_ids"].shape[0] for i in train_set]) + logger.info("Max length of training dataset: {}".format(max_length)) + + +def download_if_missing(file_path: Path, file_url: str) -> None: + """Downloads the raw json data file and saves it in the given destination.""" + if file_path.exists() and file_path.stat().st_size > 0: + return + with open(file_path, "w", encoding="utf-8") as f: + f.write(requests.get(file_url).text) + + +def prepare_sample(example: dict, tokenizer, max_length: int) -> dict: + """Processes a single sample. + Each sample in the dataset consists of: + - instruction: A string describing the task + - input: A string holding a special input value for the instruction. + This only applies to some samples, and in others this is empty. + - output: The response string + This function processes this data to produce a prompt text and a label for + supervised training. The prompt text is formed as a single message including both + the instruction and the input. The label/target is the same message but with the + response attached. + Finally, both the prompt and the label get tokenized. If desired, all tokens + in the label that correspond to the original input prompt get masked out (default). 
+ """ + full_prompt = generate_prompt(example) + full_prompt_and_response = full_prompt + example["output"] + + prompt = tokenizer.tokenize(full_prompt, add_bos=True, add_eos=False, device="cpu")[0] + example = tokenizer.tokenize( + full_prompt_and_response, add_bos=True, add_eos=True, device=None, + # device="cpu" + )[0] + + padding = max_length - example.shape[0] + if padding > 0: + example = flow.cat((example, flow.zeros(padding, dtype=flow.long) - 1)) + elif padding < 0: + example = example[:max_length] + labels = copy.deepcopy(example) + labels[: len(prompt)] = -1 + example_mask = example.ge(0) + label_mask = labels.ge(0) + example[~example_mask] = 0 + labels[~label_mask] = -1 + example = example[:-1] + labels = labels[1:] + example_mask = flow.where( + example_mask, flow.tensor(0, dtype=flow.float), flow.tensor(-float("inf")) + ) + example_mask = example_mask[:-1] + return { + "input_ids": example, + "labels": labels, + } + + +def generate_prompt(example: dict) -> str: + """Generates a standardized message to prompt the model with an instruction, optional input and a + 'response' field.""" + + if example["input"]: + return ( + "Below is an instruction that describes a task, paired with an input that provides further context. " # noqa + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" # noqa + ) + return ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Response:" + ) + + +if __name__ == "__main__": + prepare() From 702936a2266abecd746a4e2e5991b959352c5e83 Mon Sep 17 00:00:00 2001 From: Lusfie <15176063690@163.com> Date: Wed, 23 Oct 2024 17:38:58 +0800 Subject: [PATCH 2/2] fix: torch_model trans oneflow_model --- projects/Llama3/utils/llama_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/Llama3/utils/llama_loader.py b/projects/Llama3/utils/llama_loader.py index c46cb480a..4248a672f 100644 --- a/projects/Llama3/utils/llama_loader.py +++ b/projects/Llama3/utils/llama_loader.py @@ -57,7 +57,7 @@ def _convert_state_dict(self, flow_state_dict, cfg): q = oneflow_state_dict[query] k = oneflow_state_dict[key] v = oneflow_state_dict[value] - qkv = flow.cat([q, k, v], dim=0) + qkv = flow.cat([q, k, k, k, k, v, v, v, v], dim=0) # in Llama3, num_attention_heads / num_key_value_heads = 4 qkv = self._fix_qkv_ordering(qkv, head_size, num_attention_heads, hidden_size) oneflow_state_dict[new_key_qkv.format(layer_idx)] = qkv oneflow_state_dict.pop(query)