Skip to content

Commit 01a176c

Browse files
committed
first weight setter
1 parent 016b5ea commit 01a176c

File tree

1 file changed

+91
-0
lines changed

1 file changed

+91
-0
lines changed

extras/weight_loader.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import csv
2+
import torch
3+
4+
5+
#inint weights in a csv file
6+
def init_weights(path="/home/programmer/Bachelorarbeit/weights/first_it.csv",path_to_split="/home/programmer/Bachelorarbeit/split/splits.csv"):
7+
with open(path_to_split, 'r') as csvfile:
8+
with open(path, 'w') as to_file:
9+
fieldnames = ['idents','label','weights']
10+
writer = csv.writer(to_file)
11+
writer.writerow(fieldnames)
12+
reader = csv.reader(csvfile)
13+
weight = 1 / get_size(path_to_split)
14+
for row in reader:
15+
if row[1] == "train" or row[1] == "validation":
16+
#print(type(row[0]))
17+
writer.writerow([int(row[0]),row[1],weight])
18+
19+
def mock_init_weights(path="/home/programmer/Bachelorarbeit/weights/first_it.csv",path_to_split="/home/programmer/Bachelorarbeit/split/splits.csv"):
20+
with open(path_to_split, 'r') as csvfile:
21+
with open(path, 'w') as to_file:
22+
fieldnames = ['idents','label','weights']
23+
writer = csv.writer(to_file)
24+
writer.writerow(fieldnames)
25+
reader = csv.reader(csvfile)
26+
weight = 1
27+
for row in reader:
28+
if row[1] == "train" or row[1] == "validation":
29+
writer.writerow([int(row[0]),row[1],weight])
30+
weight = weight + 1
31+
32+
#check the size of a csv file given a filter for the second object
33+
# assumes csv file has a header
34+
def get_size(path="/home/programmer/Bachelorarbeit/split/splits.csv",filter=["train","validation"]) -> int:
35+
with open(path,'r') as file:
36+
reader = csv.reader(file)
37+
size = -1
38+
for row in reader:
39+
if row[1] in filter:
40+
size = size + 1
41+
return size
42+
#get a dictory with the ids and weights of the data points
43+
def get_weights(idents:tuple[int,...],path="/home/programmer/Bachelorarbeit/weights/first_it.csv")-> dict[str,float]:
44+
value = dict()
45+
for i in idents:
46+
weight = find_weight(path,i)
47+
value.update({str(i):weight})
48+
return value
49+
50+
51+
52+
#finds the weight for a specific datapoint
53+
def find_weight(path:str,ident:int)-> float:
54+
with open(path,'r') as file:
55+
reader = csv.reader(file)
56+
for row in reader:
57+
if row[0] == str(ident):
58+
return float(row[2])
59+
60+
61+
print(f"{ident} is not in file ")
62+
63+
64+
def create_data_weights(batchsize:int,dim:int,weights:dict[str,float],idents:tuple[int,...])-> torch.tensor:
65+
weight = torch.empty(batchsize,dim)
66+
index = 0
67+
for i in idents:
68+
w = weights[str(i)]
69+
for j in range(0,dim):
70+
weight[index][j] = w
71+
index = index + 1
72+
return weight
73+
74+
75+
76+
def testing():
77+
print("hello world")
78+
79+
80+
81+
82+
83+
84+
85+
86+
def create_class_weights()-> torch.tensor:
87+
pass
88+
89+
# mock_init_weights()
90+
# print(get_weights((233713,51990)))
91+

0 commit comments

Comments
 (0)