# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from dataclasses import asdict, dataclass, field
from typing import List, Optional, Union

from ...utils.env import LOKR_CONFIG_NAME


@dataclass
class LoKrConfig:
    """
    This is the configuration class to store the configuration of a [`LoKrModel`].
    Convention for LoKrModel: W1 is referred to as the scaling matrix and W2 as the adapter matrix.

    Args:
        target_modules (`Union[List[str], str]`): The names of the modules to apply LoKr to.
        trainable_modules (`List[str]`): The names of the modules to train when applying LoKr.
        lokr_alpha (`float`): The alpha parameter for LoKr scaling.
        merge_weight (`bool`):
            Whether to merge the weights of the LoKr layers with the base transformer model in `eval` mode.
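
    Example (a minimal usage sketch; the module-name patterns are illustrative, not defaults):

        >>> config = LoKrConfig(
        ...     target_modules=[".*q_proj.*", ".*v_proj.*"],
        ...     lokr_dim=8,
        ...     lokr_alpha=16.0,
        ... )
        >>> config.scaling
        2.0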
    """

    base_model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The name of the base model to use."}
    )
    target_modules: Optional[Union[List[str], str]] = field(
        default=None,
        metadata={
            "help": "List of module names or a regex expression of module names to replace with LoKr. "
            "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
        },
    )
    trainable_modules: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "List of module names or a regex expression of module names to train when applying LoKr. "
            "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
        },
    )
    trainable_bias: Optional[str] = field(
        default=None, metadata={"help": "Define trainable bias parameters for the LoKr model."}
    )
    lokr_dim: int = field(default=8, metadata={"help": "The LoKr dimension (rank) of the adapter matrix."})
    factor: int = field(default=-1, metadata={"help": "Determines the decomposition size of the LoKr matrices."})
    decompose_both: bool = field(
        default=False,
        metadata={"help": "Whether to decompose both the scaling matrix and the adapter matrix."},
    )
    lokr_alpha: float = field(
        default=0.0, metadata={"help": "Determines the scaling of the adapter weight, following the LoKr convention."}
    )
    merge_weight: bool = field(
        default=False, metadata={"help": "Merge the weights of the original model and the LoKr model."}
    )
    tensor_parallel_degree: int = field(
        default=-1, metadata={"help": "The tensor parallel degree; -1 means tensor parallelism is not used."}
    )
    dtype: Optional[str] = field(default=None, metadata={"help": "The data type of the tensors."})

    @property
    def __dict__(self):
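        # Overriding `__dict__` as a property means `vars(config)` and `to_dict()`
        # return a fresh dict of the dataclass fields built by `asdict`.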
        return asdict(self)

    def to_dict(self):
        return self.__dict__

    @property
    def scaling(self):
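        # LoKr convention: the adapter update is scaled by lokr_alpha / lokr_dim.
        # Fall back to 1.0 only when both values are unset (falsy).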
        if not (self.lokr_alpha or self.lokr_dim):
            return 1.0
        return self.lokr_alpha / self.lokr_dim

    def save_pretrained(self, save_directory):
        r"""
        This method saves the configuration of your adapter model in a directory.

        Args:
            save_directory (`str`):
                The directory where the configuration will be saved.
        """
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        os.makedirs(save_directory, exist_ok=True)

        output_dict = self.__dict__
        # Persist the derived scaling factor alongside the raw fields.
        output_dict["scaling"] = self.scaling
        output_path = os.path.join(save_directory, LOKR_CONFIG_NAME)

        # Write the configuration as pretty-printed, key-sorted JSON.
        with open(output_path, "w") as writer:
            writer.write(json.dumps(output_dict, indent=2, sort_keys=True))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""
        This method loads the configuration of your adapter model from a directory.

        Args:
            pretrained_model_name_or_path (`str`):
                The directory or the hub-id where the configuration is saved.
            **kwargs:
                Additional keyword arguments passed to the configuration initialization;
                attributes loaded from the file take precedence over them.
        """
        config_file = os.path.join(pretrained_model_name_or_path, LOKR_CONFIG_NAME)
        if not os.path.isfile(config_file):
            raise ValueError(f"Can't find {LOKR_CONFIG_NAME} at '{pretrained_model_name_or_path}'")

        loaded_attributes = cls.from_json_file(config_file)
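        # "scaling" is derived from lokr_alpha / lokr_dim (see the `scaling` property),
        # so discard any serialized value and let the property recompute it.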
        loaded_attributes.pop("scaling", None)

        config = cls(**kwargs)

        for key, value in loaded_attributes.items():
            if hasattr(config, key):
                setattr(config, key, value)

        return config

    @classmethod
    def from_json_file(cls, path_json_file):
        r"""
        Loads a configuration from a JSON file.

        Args:
            path_json_file (`str`):
                The path to the JSON file.
        """
        with open(path_json_file, "r") as file:
            json_object = json.load(file)

        return json_object
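

# A minimal save/load round-trip, as it might look from application code. The
# import path, directory name, and module patterns below are illustrative
# assumptions, not part of this file's API:
#
#     from paddlenlp.peft import LoKrConfig  # import path assumed
#
#     config = LoKrConfig(target_modules=["q", "v"], lokr_dim=8, lokr_alpha=16.0)
#     config.save_pretrained("./lokr_checkpoint")  # writes LOKR_CONFIG_NAME into the directory
#     restored = LoKrConfig.from_pretrained("./lokr_checkpoint")
#     assert restored.lokr_alpha == config.lokr_alpha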