Skip to content

Commit 3fdcf59

Browse files
committed
creating weight tensor more efficient
1 parent 1c5f5bc commit 3fdcf59

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

extras/weight_loader.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,14 @@ def find_label(id:int,path="../split/splits.csv")-> str:
7373
# return should be a tuple of weigths matching the sequenece of the target and label tensor
7474
def create_data_weights(batchsize:int,dim:int,weights:dict[str,float],idents:tuple[int,...])-> torch.tensor:
7575
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
76-
weight = torch.empty(batchsize,dim,device=device)
76+
weight = None
7777
index = 0
7878
for i in idents:
79-
w = weights[str(i)]
80-
for j in range(0,dim):
81-
weight[index][j] = float(w)
79+
w = torch.full((1,dim),float(weights[str(i)]),device=device)
80+
if weight == None:
81+
weight = w
82+
else:
83+
weight = torch.cat((weight,w),0)
8284
index = index + 1
8385
return weight
8486

0 commit comments

Comments
 (0)