
Commit 2334b4d

[LLM]add unify llm (#6695)
* add unify llm
* add
* tiny fix
* delete unused files
* add quant
* add quant
* add json
* pp not yet support
* update according to comments
* add 4d explaination
* fix
* tiny
* update
* tiny
* tiny
1 parent 46e5551 commit 2334b4d


71 files changed: +1004, -6641 lines

llm/README.md

Lines changed: 304 additions & 22 deletions
Large diffs are not rendered by default.

llm/causallm/argument.py renamed to llm/argument.py

Lines changed: 13 additions & 18 deletions
@@ -27,7 +27,7 @@ class DataArgument:
     )
     intokens: bool = field(default=False, metadata={"help": "Whether to use InTokens data stream"})
     intokens_max_length: int = field(
-        default=1024,
+        default=2048,
         metadata={"help": "The max length for InTokens data stream. Only effective when intokens is True"},
     )
 
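For context, the `intokens` options above configure a packed data stream whose default buffer length moves from 1024 to 2048 tokens. A rough sketch of the packing idea (greedy concatenation of tokenized examples up to `intokens_max_length`), not PaddleNLP's actual InTokens implementation:

```python
from typing import List


def pack_examples(examples: List[List[int]], max_length: int = 2048) -> List[List[int]]:
    """Greedily concatenate tokenized examples into buffers of at most max_length tokens.

    Only an illustration of what an InTokens-style packed data stream does; the real
    implementation also has to track per-example attention boundaries inside a buffer.
    """
    packed: List[List[int]] = []
    buffer: List[int] = []
    for tokens in examples:
        if buffer and len(buffer) + len(tokens) > max_length:
            packed.append(buffer)
            buffer = []
        buffer.extend(tokens[:max_length])  # truncate a single over-long example
    if buffer:
        packed.append(buffer)
    return packed


# Three short "tokenized" sequences fit into one packed buffer of <= 2048 tokens.
print(pack_examples([[1, 2, 3], [4, 5], [6]]))  # [[1, 2, 3, 4, 5, 6]]
```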

@@ -43,9 +43,6 @@ class ModelArgument:
     lora: bool = field(default=False, metadata={"help": "Whether to use LoRA technique"})
     lora_path: str = field(default=None, metadata={"help": "Initialize lora state dict."})
     lora_rank: int = field(default=8, metadata={"help": "Lora attention dimension"})
-    lora_merge_weights: bool = field(
-        default=False, metadata={"help": "Merge weights of the original model and the Lora model"}
-    )
 
     # prefix tuning related parameters
     prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"})
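The removed `lora_merge_weights` flag toggled folding the LoRA update back into the base weights. A minimal sketch of what such a merge computes in the standard LoRA formulation, W' = W + (alpha / r) * B A; this is a generic illustration, not this repository's code:

```python
import numpy as np


def merge_lora(weight: np.ndarray, lora_a: np.ndarray, lora_b: np.ndarray,
               alpha: float = 16.0, rank: int = 8) -> np.ndarray:
    """Fold a LoRA update into the base weight: W' = W + (alpha / rank) * (B @ A).

    Generic illustration only; shapes are W: (out, in), A: (rank, in), B: (out, rank).
    """
    return weight + (alpha / rank) * (lora_b @ lora_a)


w = np.zeros((4, 6))
a = np.random.randn(8, 6)   # rank 8
b = np.random.randn(4, 8)
print(merge_lora(w, a, b).shape)  # (4, 6)
```
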
@@ -54,38 +51,41 @@ class ModelArgument:
 
 @dataclass
 class QuantArgument:
-    quant_type: str = field(default="A8W8", metadata={"help": "Quantization type. Supported values: A8W8, W4,A8W4"})
+    quant_type: str = field(
+        default="A8W8", metadata={"help": "Quantization type. Supported values: A8W8, WINT4,WINT8"}
+    )
 
     # QAT related parameters
+    # Not Yet support
     do_qat: bool = field(default=False, metadata={"help": "Whether to use QAT technique"})
 
-    # GPTQ related parameters
-    do_gptq: bool = field(default=False, metadata={"help": "Whether to use GPTQ"})
-    gptq_step: int = field(default=8, metadata={"help": "Step for GPTQ"})
-
     # PTQ related parameters
     do_ptq: bool = field(default=False, metadata={"help": "Whether to use PTQ"})
-    ptq_step: int = field(default=8, metadata={"help": "Step for PTQ"})
+    ptq_step: int = field(default=32, metadata={"help": "Step for PTQ"})
 
     shift: bool = field(default=False, metadata={"help": "Whether to use Shift"})
     shift_all_linears: bool = field(default=False, metadata={"help": "Whether to shift all linears"})
     shift_sampler: str = field(
         default="ema", metadata={"help": "The name of shift sampler, choosen from ['ema', 'none']"}
     )
-    shift_step: int = field(default=8, metadata={"help": "Sample steps when shift"})
+    shift_step: int = field(default=32, metadata={"help": "Sample steps when shift"})
 
     smooth: bool = field(default=False, metadata={"help": "Whether to use Smooth"})
     smooth_all_linears: bool = field(default=False, metadata={"help": "Whether to smooth all linears"})
     smooth_sampler: str = field(
         default="none", metadata={"help": "The name of smooth sampler, choosen from ['multi_step','none']"}
     )
-    smooth_step: int = field(default=8, metadata={"help": "Sample steps when smooth"})
+    smooth_step: int = field(default=32, metadata={"help": "Sample steps when smooth"})
     smooth_piecewise_search: bool = field(
         default=False, metadata={"help": "The number of piece in piecewise search for smooth strategy."}
     )
-    smooth_k_piece: int = field(default=6, metadata={"help": "Number of pieces for K-search"})
+    smooth_k_piece: int = field(default=3, metadata={"help": "Number of pieces for K-search"})
     smooth_search_piece: bool = field(default=False, metadata={"help": "Whether search k_piece when piecewise search"})
 
+    # GPTQ related parameters
+    do_gptq: bool = field(default=False, metadata={"help": "Whether to use GPTQ"})
+    gptq_step: int = field(default=8, metadata={"help": "Step for GPTQ"})
+
 
 @dataclass
 class GenerateArgument:
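The `shift` and `smooth` fields configure activation-outlier handling before PTQ. As a generic illustration of the SmoothQuant-style idea behind `smooth` (migrate activation outliers into the weights with per-channel scales), not necessarily the exact algorithm used here:

```python
import numpy as np


def smooth_scales(act_absmax: np.ndarray, w_absmax: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Per-channel smoothing scales s_j = max|X_j|**alpha / max|W_j|**(1 - alpha).

    Dividing activation channels by s and multiplying the matching weight rows by s
    keeps X @ W unchanged while flattening activation outliers before quantization.
    """
    return np.power(act_absmax, alpha) / np.power(w_absmax, 1.0 - alpha)


act_absmax = np.array([30.0, 0.5, 4.0])  # calibrated per-channel activation maxima
w_absmax = np.array([1.0, 2.0, 0.8])     # per-channel weight maxima
s = smooth_scales(act_absmax, w_absmax)

x = np.random.randn(2, 3) * act_absmax
w = np.random.randn(3, 5) * w_absmax[:, None]
# Equivalence check: (X / s) @ (diag(s) @ W) == X @ W
assert np.allclose((x / s) @ (np.diag(s) @ w), x @ w)
```
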
@@ -98,8 +98,3 @@ class GenerateArgument:
     top_p: float = field(
         default=1.0, metadata={"help": "The cumulative probability for top-p-filtering in the sampling strategy."}
     )
-    num_beams: int = field(default=1, metadata={"help": "The number of beams in the beam_search strategy."})
-    decode_strategy: str = field(default="sampling", metadata={"help": "The decoding strategy in generation."})
-    repetition_penalty: float = field(
-        default=1.0, metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
-    )
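With the per-model scripts unified under `llm/`, these dataclasses are presumably consumed by a single entry point via PaddleNLP's `PdArgumentParser`. A minimal sketch of that pattern, assuming the parser mirrors the Hugging Face argument-parser API and using hypothetical stand-in dataclasses rather than the full set defined in `llm/argument.py`:

```python
import sys
from dataclasses import dataclass, field

from paddlenlp.trainer import PdArgumentParser, TrainingArguments


# Hypothetical stand-ins; the real dataclasses live in llm/argument.py.
@dataclass
class ModelArgument:
    model_name_or_path: str = field(default=None, metadata={"help": "Model name or local path."})
    lora: bool = field(default=False, metadata={"help": "Whether to use LoRA technique"})


@dataclass
class DataArgument:
    dataset_name_or_path: str = field(default=None, metadata={"help": "Dataset name or local path."})
    intokens: bool = field(default=False, metadata={"help": "Whether to use InTokens data stream"})


if __name__ == "__main__":
    parser = PdArgumentParser((ModelArgument, DataArgument, TrainingArguments))
    if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
        # A single JSON config file instead of CLI flags (the commit also mentions "add json").
        model_args, data_args, training_args = parser.parse_json_file(sys.argv[1])
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    print(model_args, data_args)
```

A launch would then pass either flags or one JSON config to the unified script; the exact script and config file names are placeholders here, not taken from the commit.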

llm/bloom/README.md

Lines changed: 0 additions & 265 deletions
This file was deleted.
