
Commit fd4fb6c

Merge pull request #31 from JustinGoheen/update-packaging

Update packaging

2 parents 409eb44 + 35a6103, commit fd4fb6c

25 files changed: +760 / -599 lines

.gitignore

Lines changed: 162 additions & 2 deletions
@@ -1,3 +1,163 @@
+# Byte-compiled / optimized / DLL files
 __pycache__/
-*.pth
-*bak.py
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# DS store
+.DS_Store

.pre-commit-config.yaml

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+default_language_version:
+  python: python3
+
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v3.2.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+
+  - repo: https://github.com/PyCQA/isort
+    rev: 5.10.1
+    hooks:
+      - id: isort
+        name: Format imports
+
+  - repo: https://github.com/psf/black
+    rev: 22.3.0
+    hooks:
+      - id: black
+        name: Format code
+
+  - repo: https://github.com/asottile/blacken-docs
+    rev: v1.12.1
+    hooks:
+      - id: blacken-docs
+        args: [--line-length=120]
+        additional_dependencies: [black==21.12b0]
+
+  - repo: https://github.com/charliermarsh/ruff-pre-commit
+    rev: "v0.0.237"
+    hooks:
+      - id: ruff
+        args: ["--fix"]

alpaca_finetuning_v1/engine_finetuning.py

Lines changed: 43 additions & 32 deletions
@@ -3,31 +3,36 @@
 from typing import Iterable
 
 import torch
-
-import util.misc as misc
 import util.lr_sched as lr_sched
+import util.misc as misc
 
 
+def train_one_epoch(
+    model: torch.nn.Module,
+    data_loader: Iterable,
+    optimizer: torch.optim.Optimizer,
+    device: torch.device,
+    epoch: int,
+    loss_scaler,
+    log_writer=None,
+    args=None,
+):
 
-def train_one_epoch(model: torch.nn.Module,
-                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
-                    device: torch.device, epoch: int, loss_scaler,
-                    log_writer=None,
-                    args=None):
-
     model.train(True)
     metric_logger = misc.MetricLogger(delimiter="  ")
-    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
-    header = 'Epoch: [{}]'.format(epoch)
+    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
+    header = "Epoch: [{}]".format(epoch)
     print_freq = 10
 
     accum_iter = args.accum_iter
 
     optimizer.zero_grad()
 
     if log_writer is not None:
-        print('log_dir: {}'.format(log_writer.log_dir))
-    for data_iter_step, (examples, labels, example_mask) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+        print("log_dir: {}".format(log_writer.log_dir))
+    for data_iter_step, (examples, labels, example_mask) in enumerate(
+        metric_logger.log_every(data_loader, print_freq, header)
+    ):
         # we use a per iteration (instead of per epoch) lr scheduler
         if data_iter_step % accum_iter == 0:
             lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
@@ -43,8 +48,7 @@ def train_one_epoch(model: torch.nn.Module,
 
         loss /= accum_iter
 
-        loss_scaler(loss, optimizer, parameters=model.parameters(),
-                    update_grad=(data_iter_step + 1) % accum_iter == 0)
+        loss_scaler(loss, optimizer, parameters=model.parameters(), update_grad=(data_iter_step + 1) % accum_iter == 0)
         if (data_iter_step + 1) % accum_iter == 0:
             optimizer.zero_grad()
 
@@ -55,42 +59,49 @@ def train_one_epoch(model: torch.nn.Module,
         lr = optimizer.param_groups[0]["lr"]
         metric_logger.update(lr=lr)
 
-        loss_value_reduce = misc.all_reduce_mean(loss_value)
+        misc.all_reduce_mean(loss_value)
         c_loss_value_reduce = misc.all_reduce_mean(c_loss_value)
 
         if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
-            """ We use epoch_1000x as the x-axis in tensorboard.
+            """We use epoch_1000x as the x-axis in tensorboard.
             This calibrates different curves when batch size changes.
             """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
-            log_writer.add_scalar('c_train_loss', c_loss_value_reduce, epoch_1000x)
-            log_writer.add_scalar('lr', lr, epoch_1000x)
+            log_writer.add_scalar("c_train_loss", c_loss_value_reduce, epoch_1000x)
+            log_writer.add_scalar("lr", lr, epoch_1000x)
 
     # gather the stats from all processes
     metric_logger.synchronize_between_processes()
     print("Averaged stats:", metric_logger)
     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
 
 
-def val_one_epoch(model: torch.nn.Module,
-                  data_loader: Iterable, optimizer: torch.optim.Optimizer,
-                  device: torch.device, epoch: int, loss_scaler,
-                  log_writer=None,
-                  args=None):
+def val_one_epoch(
+    model: torch.nn.Module,
+    data_loader: Iterable,
+    optimizer: torch.optim.Optimizer,
+    device: torch.device,
+    epoch: int,
+    loss_scaler,
+    log_writer=None,
+    args=None,
+):
     model.eval()
     metric_logger = misc.MetricLogger(delimiter="  ")
-    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
-    header = 'Epoch: [{}]'.format(epoch)
+    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
+    header = "Epoch: [{}]".format(epoch)
     print_freq = 10
 
     accum_iter = args.accum_iter
 
    if log_writer is not None:
-        print('log_dir: {}'.format(log_writer.log_dir))
-    for data_iter_step, (examples, labels, example_mask) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+        print("log_dir: {}".format(log_writer.log_dir))
+    for data_iter_step, (examples, labels, example_mask) in enumerate(
+        metric_logger.log_every(data_loader, print_freq, header)
+    ):
 
         with torch.no_grad():
-            c_loss = model(examples, labels)
+            c_loss = model(examples, labels)
         loss = c_loss
         loss_value = loss.item()
 
@@ -105,15 +116,15 @@ def val_one_epoch(model: torch.nn.Module,
         lr = optimizer.param_groups[0]["lr"]
         metric_logger.update(lr=lr)
 
-        loss_value_reduce = misc.all_reduce_mean(loss_value)
+        misc.all_reduce_mean(loss_value)
         c_loss_value_reduce = misc.all_reduce_mean(c_loss_value)
         if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
-            """ We use epoch_1000x as the x-axis in tensorboard.
+            """We use epoch_1000x as the x-axis in tensorboard.
             This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
-            log_writer.add_scalar('c_train_loss', c_loss_value_reduce, epoch_1000x)
-            log_writer.add_scalar('lr', lr, epoch_1000x)
+            log_writer.add_scalar("c_train_loss", c_loss_value_reduce, epoch_1000x)
+            log_writer.add_scalar("lr", lr, epoch_1000x)
 
     # gather the stats from all processes
     metric_logger.synchronize_between_processes()
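
The reformatted loss_scaler call above keeps the gradient-accumulation behaviour unchanged: the loss is pre-divided by accum_iter and the optimizer only steps when update_grad is true. As a minimal, self-contained sketch of that pattern in plain PyTorch (this is not the repo's loss_scaler or AMP handling; the model, data, and accum_iter value are made up for illustration):

import torch

# Illustrative gradient-accumulation loop: gradients build up over accum_iter
# iterations, and the optimizer only steps on the final micro-batch of each group.
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accum_iter = 4  # hypothetical value; the real code reads args.accum_iter

data = [(torch.randn(2, 8), torch.randn(2, 1)) for _ in range(8)]
optimizer.zero_grad()
for data_iter_step, (x, y) in enumerate(data):
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss = loss / accum_iter          # scale so the accumulated gradient matches a full batch
    loss.backward()                   # loss_scaler performs this (plus loss scaling) in the real code
    if (data_iter_step + 1) % accum_iter == 0:
        optimizer.step()              # the update_grad=True path
        optimizer.zero_grad()
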
Lines changed: 7 additions & 9 deletions
@@ -1,17 +1,15 @@
 import torch
 
-
-
-model = torch.load('./checkpoint/checkpoint-4.pth', map_location='cpu')
+model = torch.load("./checkpoint/checkpoint-4.pth", map_location="cpu")
 new_model = dict()
-weight_list = ['layers.' + str(i) + '.attention.gate' for i in range(32)]
-old_weight_list = ['layers.' + str(i) + '.attention.gate' for i in range(32)]
-weight_list = weight_list + ['adapter_query.weight']
+weight_list = ["layers." + str(i) + ".attention.gate" for i in range(32)]
+old_weight_list = ["layers." + str(i) + ".attention.gate" for i in range(32)]
+weight_list = weight_list + ["adapter_query.weight"]
 
 print(weight_list)
-print(model['model']['adapter_query.weight'].shape)
+print(model["model"]["adapter_query.weight"].shape)
 
 for i in range(len(weight_list)):
-    new_model[weight_list[i]] = model['model'][weight_list[i]]
+    new_model[weight_list[i]] = model["model"][weight_list[i]]
 
-torch.save(new_model, 'adapter_adapter_len10_layer30_epoch5.pth')
+torch.save(new_model, "adapter_adapter_len10_layer30_epoch5.pth")
