Skip to content

Commit 9a20c3b

Browse files
Format code (#932)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 2969059 commit 9a20c3b

File tree

2 files changed

+29
-19
lines changed

2 files changed

+29
-19
lines changed

infer_cli.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -234,16 +234,12 @@ def get_vc(model_path):
234234
version = cpt.get("version", "v1")
235235
if version == "v1":
236236
if if_f0 == 1:
237-
net_g = SynthesizerTrnMs256NSFsid(
238-
*cpt["config"], is_half=is_half
239-
)
237+
net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=is_half)
240238
else:
241239
net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
242240
elif version == "v2":
243241
if if_f0 == 1:
244-
net_g = SynthesizerTrnMs768NSFsid(
245-
*cpt["config"], is_half=is_half
246-
)
242+
net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=is_half)
247243
else:
248244
net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
249245
del net_g.enc_q

tools/calc_rvc_model_similarity.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# This code references https://huggingface.co/JosephusCheung/ASimilarityCalculatior/blob/main/qwerty.py
22
# Fill in the path of the model to be queried and the root directory of the reference models, and this script will return the similarity between the model to be queried and all reference models.
3-
import sys,os
3+
import sys, os
44
import torch
55
import torch.nn as nn
66
import torch.nn.functional as F
77

8+
89
def cal_cross_attn(to_q, to_k, to_v, rand_input):
910
hidden_dim, embed_dim = to_q.shape
1011
attn_to_q = nn.Linear(hidden_dim, embed_dim, bias=False)
@@ -16,41 +17,50 @@ def cal_cross_attn(to_q, to_k, to_v, rand_input):
1617

1718
return torch.einsum(
1819
"ik, jk -> ik",
19-
F.softmax(torch.einsum("ij, kj -> ik", attn_to_q(rand_input), attn_to_k(rand_input)), dim=-1),
20-
attn_to_v(rand_input)
20+
F.softmax(
21+
torch.einsum("ij, kj -> ik", attn_to_q(rand_input), attn_to_k(rand_input)),
22+
dim=-1,
23+
),
24+
attn_to_v(rand_input),
2125
)
2226

27+
2328
def model_hash(filename):
2429
try:
2530
with open(filename, "rb") as file:
2631
import hashlib
32+
2733
m = hashlib.sha256()
2834

2935
file.seek(0x100000)
3036
m.update(file.read(0x10000))
3137
return m.hexdigest()[0:8]
3238
except FileNotFoundError:
33-
return 'NOFILE'
39+
return "NOFILE"
40+
3441

3542
def eval(model, n, input):
3643
qk = f"enc_p.encoder.attn_layers.{n}.conv_q.weight"
3744
uk = f"enc_p.encoder.attn_layers.{n}.conv_k.weight"
3845
vk = f"enc_p.encoder.attn_layers.{n}.conv_v.weight"
39-
atoq, atok, atov = model[qk][:,:,0], model[uk][:,:,0], model[vk][:,:,0]
46+
atoq, atok, atov = model[qk][:, :, 0], model[uk][:, :, 0], model[vk][:, :, 0]
4047

4148
attn = cal_cross_attn(atoq, atok, atov, input)
4249
return attn
4350

44-
def main(path,root):
51+
52+
def main(path, root):
4553
torch.manual_seed(114514)
4654
model_a = torch.load(path, map_location="cpu")["weight"]
4755

48-
print("query:\t\t%s\t%s"%(path,model_hash(path)))
56+
print("query:\t\t%s\t%s" % (path, model_hash(path)))
4957

5058
map_attn_a = {}
5159
map_rand_input = {}
5260
for n in range(6):
53-
hidden_dim, embed_dim,_ = model_a[f"enc_p.encoder.attn_layers.{n}.conv_v.weight"].shape
61+
hidden_dim, embed_dim, _ = model_a[
62+
f"enc_p.encoder.attn_layers.{n}.conv_v.weight"
63+
].shape
5464
rand_input = torch.randn([embed_dim, hidden_dim])
5565

5666
map_attn_a[n] = eval(model_a, n, rand_input)
@@ -59,7 +69,7 @@ def main(path,root):
5969
del model_a
6070

6171
for name in sorted(list(os.listdir(root))):
62-
path="%s/%s"%(root,name)
72+
path = "%s/%s" % (root, name)
6373
model_b = torch.load(path, map_location="cpu")["weight"]
6474

6575
sims = []
@@ -70,9 +80,13 @@ def main(path,root):
7080
sim = torch.mean(torch.cosine_similarity(attn_a, attn_b))
7181
sims.append(sim)
7282

73-
print("reference:\t%s\t%s\t%s"%(path,model_hash(path),f"{torch.mean(torch.stack(sims)) * 1e2:.2f}%"))
83+
print(
84+
"reference:\t%s\t%s\t%s"
85+
% (path, model_hash(path), f"{torch.mean(torch.stack(sims)) * 1e2:.2f}%")
86+
)
87+
7488

7589
if __name__ == "__main__":
76-
query_path=r"weights\mi v3.pth"
77-
reference_root=r"weights"
78-
main(query_path,reference_root)
90+
query_path = r"weights\mi v3.pth"
91+
reference_root = r"weights"
92+
main(query_path, reference_root)

0 commit comments

Comments
 (0)