Skip to content

Commit 40748b8

Browse files
authored
Merge pull request #1314 from solopku/patch-18
Add implementations for the PeftModel class
2 parents c9948c3 + 3822e42 commit 40748b8

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
20+
from singa import model
21+
from singa_peft.peft_config import PeftConfig
22+
from singa_peft.peft_registry import PeftRegistry
23+
24+
25+
class PeftModel(model.Model):
26+
"""
27+
PeftModel: modify the base model based on the peft config. A Wrapper of model and tuner.
28+
"""
29+
def __init__(self, base_model: model.Model, peft_config: PeftConfig):
30+
r"""
31+
Args:
32+
base_model: the base model
33+
peft_config: the config of peft
34+
"""
35+
super().__init__()
36+
self.base_model = base_model
37+
self.peft_config = peft_config
38+
self.peft_type = peft_config.peft_type
39+
self.dimension = self.base_model.dimension
40+
# Get the injected tuner class based on peft_type
41+
cls = PeftRegistry.get_tuner(self.peft_type)
42+
self.tuner = cls(peft_config)
43+
# Inject adapter into base_model
44+
self.base_model = self.tuner.inject(base_model)
45+
46+
def forward(self, inputs):
47+
return self.base_model.forward(inputs)
48+
49+
def train_one_batch(self, x, y, dist_option, spars):
50+
return self.base_model.train_one_batch(x, y, dist_option, spars)
51+
52+
def set_optimizer(self, optimizer):
53+
self.base_model.set_optimizer(optimizer)
54+
55+
def compile(self, inputs, is_train=True, use_graph=False, sequential=False):
56+
self.base_model.compile(inputs, is_train, use_graph, sequential)
57+
58+
def train(self, mode=True):
59+
super().train(mode)
60+
self.base_model.train(mode)
61+
62+
def eval(self):
63+
super().eval()
64+
self.base_model.eval()
65+
66+
def merge_weights(self, mode=True):
67+
self.tuner.merge_weights(self.base_model, mode)
68+
69+
def get_params(self):
70+
params = self.base_model.get_params()
71+
return params
72+
73+
def set_params(self, params):
74+
self.base_model.set_params(params)
75+
76+
77+
def get_peft_model(base_model: model.Model, peft_config: PeftConfig):
78+
r"""
79+
Args:
80+
base_model: the base model
81+
peft_config: the config of peft
82+
83+
Returns: a peft model based on peft config
84+
"""
85+
peft_model = PeftModel(base_model, peft_config)
86+
return peft_model

0 commit comments

Comments
 (0)