Skip to content

Commit fbfb4a3

Browse files
authored
Merge pull request #1304 from zmeihui/25-5-13-dev
Add the peft registry module
2 parents a3de028 + 74f4922 commit fbfb4a3

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
20+
class PeftRegistry:
21+
"""
22+
PeftRegistry: the registry class for peft method
23+
"""
24+
25+
_tuners = {}
26+
27+
@classmethod
28+
def register(cls, tuner_name):
29+
r"""
30+
Register the Tuner class decorator
31+
Args:
32+
tuner_name: the name of the Tuner
33+
34+
Returns: the class of decorator
35+
"""
36+
def decorator(tuner_class):
37+
cls._tuners[tuner_name] = tuner_class
38+
return tuner_class
39+
return decorator
40+
41+
@classmethod
42+
def get_tuner(cls, tuner_name):
43+
r"""
44+
Get the Tuner class by name
45+
Args:
46+
tuner_name: the name of the Tuner
47+
48+
Returns: the class of the Tuner
49+
"""
50+
if tuner_name not in cls._tuners:
51+
raise ValueError(f"Unsupported peft method: {tuner_name}")
52+
return cls._tuners[tuner_name]

0 commit comments

Comments
 (0)