11import csv
22import torch
3+ import os
34
45
56#inint weights in a csv file
6- def init_weights (path = "/home/programmer/Bachelorarbeit/weights/first_it.csv" ,path_to_split = "/home/programmer/Bachelorarbeit/split/splits.csv" ):
7- with open (path_to_split , 'r' ) as csvfile :
8- with open (path , 'w' ) as to_file :
9- fieldnames = ['idents' ,'label' ,'weights' ]
10- writer = csv .writer (to_file )
11- writer .writerow (fieldnames )
12- reader = csv .reader (csvfile )
13- weight = 1 / get_size (path_to_split )
14- for row in reader :
15- if row [1 ] == "train" or row [1 ] == "validation" :
16- #print(type(row[0]))
17- writer .writerow ([int (row [0 ]),row [1 ],weight ])
18-
19- def mock_init_weights (path = "/home/programmer/Bachelorarbeit/weights/first_it.csv" ,path_to_split = "/home/programmer/Bachelorarbeit/split/splits.csv" ):
7+ def init_weights (path = "../weights/first_it.csv" ,path_to_split = "../split/splits.csv" ):
8+ if not os .path .exists ("../weights/first_it.csv" ):
9+ with open (path_to_split , 'r' ) as csvfile :
10+ with open (path , 'w' ) as to_file :
11+ fieldnames = ['idents' ,'label' ,'weights' ]
12+ writer = csv .writer (to_file )
13+ writer .writerow (fieldnames )
14+ reader = csv .reader (csvfile )
15+ weight = 1 / get_size (path_to_split )
16+ for row in reader :
17+ if row [1 ] == "train" or row [1 ] == "validation" :
18+ #print(type(row[0]))
19+ writer .writerow ([int (row [0 ]),row [1 ],weight ])
20+
21+ def mock_init_weights (path = "../weights/first_it.csv" ,path_to_split = "../split/splits.csv" ):
2022 with open (path_to_split , 'r' ) as csvfile :
2123 with open (path , 'w' ) as to_file :
2224 fieldnames = ['idents' ,'label' ,'weights' ]
@@ -31,7 +33,7 @@ def mock_init_weights(path="/home/programmer/Bachelorarbeit/weights/first_it.csv
3133
3234#check the size of a csv file given a filter for the second object
3335# assumes csv file has a header
34- def get_size (path = "/home/programmer/Bachelorarbeit /split/splits.csv" ,filter = ["train" ]) -> int :
36+ def get_size (path = ".. /split/splits.csv" ,filter = ["train" ]) -> int :
3537 with open (path ,'r' ) as file :
3638 reader = csv .reader (file )
3739 size = - 1
@@ -40,15 +42,13 @@ def get_size(path="/home/programmer/Bachelorarbeit/split/splits.csv",filter=["tr
4042 size = size + 1
4143 return size
4244#get a dictory with the ids and weights of the data points
43- def get_weights (idents :tuple [int ,...],path = "/home/programmer/Bachelorarbeit /weights/first_it.csv" )-> dict [str ,float ]:
45+ def get_weights (idents :tuple [int ,...],path = ".. /weights/first_it.csv" )-> dict [str ,float ]:
4446 value = dict ()
4547 for i in idents :
4648 weight = find_weight (path ,i )
4749 value .update ({str (i ):weight })
4850 return value
4951
50-
51-
5252#finds the weight for a specific datapoint
5353def find_weight (path :str ,ident :int )-> float :
5454 with open (path ,'r' ) as file :
@@ -60,7 +60,7 @@ def find_weight(path:str,ident:int)-> float:
6060 label = find_label (id = ident )
6161 print (f"{ ident } is not in file with { label } " )
6262
63- def find_label (id :int ,path = "/home/programmer/Bachelorarbeit /split/splits.csv" )-> str :
63+ def find_label (id :int ,path = ".. /split/splits.csv" )-> str :
6464 with open (path ,'r' ) as file :
6565 reader = csv .reader (file )
6666 for row in reader :
@@ -81,8 +81,6 @@ def create_data_weights(batchsize:int,dim:int,weights:dict[str,float],idents:tup
8181 index = index + 1
8282 return weight
8383
84-
85-
8684def testing ():
8785 print ("hello world" )
8886
@@ -99,6 +97,6 @@ def create_weight_tensor(weight:float)-> torch.tensor:
9997def create_class_weights ()-> torch .tensor :
10098 pass
10199
102- mock_init_weights ()
100+ # mock_init_weights()
103101# print(get_weights((233713,51990)))
104102
0 commit comments