Skip to content

Commit aeb2cab

Browse files
authored
Merge pull request #9 from neph1/update-v0.7.0
add rename lora script
2 parents 4b93efc + 9f4aaef commit aeb2cab

File tree

6 files changed

+89
-6
lines changed

6 files changed

+89
-6
lines changed

app.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
import gradio as gr
44

5-
from config import global_config
65
from tabs.general_tab import GeneralTab
76
from tabs.prepare_tab import PrepareDatasetTab
8-
from tabs.tab import Tab
7+
from tabs.tool_tab import ToolTab
98
from tabs.training_tab import TrainingTab
109
from tabs.training_tab_legacy import LegacyTrainingTab
1110

@@ -22,28 +21,32 @@ def __init__(self):
2221
self.setup_views()
2322

2423
def setup_views(self):
25-
with gr.Blocks() as demo:
24+
with gr.Blocks() as app:
2625
gr.Markdown("### finetrainers ui")
2726

2827

2928
with gr.Tab("General Settings"):
3029
self.tabs['general'] = GeneralTab("General Settings", os.path.join(self.configs_path, "editor.yaml"))
3130
runtime_tab = gr.Tab("Trainer Settings")
31+
tools_tab = gr.Tab("Lora Tools")
32+
3233
prepare_tab = gr.Tab("Prepare dataset (Legacy)")
3334
runtime_tab_legacy = gr.Tab("Legacy Training Settings")
3435

3536
with runtime_tab:
3637
self.tabs['runtime'] = TrainingTab("Trainer Settings", os.path.join(self.configs_path, "config_template.yaml"), allow_load=True)
3738

38-
39+
with tools_tab:
40+
self.tabs['tools'] = ToolTab()
41+
3942
with prepare_tab:
4043
self.tabs['prepare'] = PrepareDatasetTab("Prepare dataset (Legacy)", os.path.join(self.configs_path, "prepare_template.yaml"), allow_load=True)
4144

4245
with runtime_tab_legacy:
4346
self.tabs['runtime'] = LegacyTrainingTab("Legacy CogvideoX Settings", os.path.join(self.configs_path, "config_template_legacy.yaml"), allow_load=True)
4447

4548

46-
demo.launch()
49+
app.launch()
4750

4851

4952
if __name__ == "__main__":

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
gradio
1+
gradio
2+
torch>=2.4.1

scripts/__init__.py

Whitespace-only changes.

scripts/common_io.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import os
2+
import torch
3+
from safetensors.torch import load_file, save_file, safe_open
4+
5+
def load_state_dict(file_name, dtype):
6+
if os.path.splitext(file_name)[1] == ".safetensors":
7+
sd = load_file(file_name)
8+
with safe_open(file_name, framework="pt") as f:
9+
metadata = f.metadata()
10+
else:
11+
sd = torch.load(file_name, map_location="cpu")
12+
metadata = None
13+
14+
for key in list(sd.keys()):
15+
if type(sd[key]) == torch.Tensor:
16+
sd[key] = sd[key].to(dtype)
17+
18+
return sd, metadata
19+
20+
def save_to_file(file_name, state_dict, dtype, metadata):
21+
if dtype is not None:
22+
for key in list(state_dict.keys()):
23+
if type(state_dict[key]) == torch.Tensor:
24+
state_dict[key] = state_dict[key].to(dtype)
25+
if os.path.splitext(file_name)[1] == ".safetensors":
26+
save_file(state_dict, file_name, metadata)
27+
else:
28+
torch.save(state_dict, file_name)

scripts/rename_keys.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import argparse
2+
import torch
3+
4+
from scripts.common_io import load_state_dict, save_to_file
5+
6+
7+
def rename_keys(file, outfile: str)-> bool:
8+
sd, metadata = load_state_dict(file, torch.float32)
9+
10+
keys_to_normalize = [key for key in sd.keys()]
11+
values_to_normalize = [sd[key].to(torch.float32) for key in keys_to_normalize]
12+
new_sd = dict()
13+
for key, value in zip(keys_to_normalize, values_to_normalize):
14+
new_sd[key.replace("transformer.", "")] = value
15+
16+
save_to_file(outfile, new_sd, torch.float16, metadata)
17+
return True
18+
19+
def setup_parser() -> argparse.ArgumentParser:
20+
parser = argparse.ArgumentParser()
21+
parser.add_argument("file", type=str, help="lora model to modify")
22+
parser.add_argument("-o", "--outfile", type=str, help="file to save to")
23+
return parser
24+
25+
26+
if __name__ == "__main__":
27+
parser = setup_parser()
28+
args = parser.parse_args()
29+
rename_keys(args.file, args.outfile)

tabs/tool_tab.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import gradio as gr
2+
3+
from scripts import rename_keys
4+
5+
class ToolTab:
6+
7+
def __init__(self):
8+
gr.Markdown("Lora Tools")
9+
self._build_rename()
10+
11+
def _build_rename(self):
12+
gr.Markdown("Rename Transformer Keys")
13+
rename_in_file = gr.Textbox(value="", label="Lora to rename (full path)", lines=1)
14+
rename_out_file = gr.Textbox(value="", label="New name (full path)", lines=1)
15+
rename_result = gr.Textbox(value="", label="Result", lines=1)
16+
rename_button = gr.Button("Rename", key='rename')
17+
rename_button.click(self.rename_file, inputs=[rename_in_file, rename_out_file], outputs=[rename_result])
18+
19+
def rename_file(self, lora_name, new_name):
20+
result = rename_keys.rename_keys(lora_name, new_name)
21+
if result:
22+
return "Renamed successfully"

0 commit comments

Comments
 (0)