Skip to content

Commit 2fb127c

Browse files
committed
added support for negative sampling
1 parent 4150f18 commit 2fb127c

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

mitdeeplearning/lab3.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(self, data_path, batch_size, training=True):
7676
self.train_inds = np.concatenate((self.pos_train_inds, self.neg_train_inds))
7777
self.batch_size = batch_size
7878
self.p_pos = np.ones(self.pos_train_inds.shape) / len(self.pos_train_inds)
79+
self.p_neg = np.ones(self.neg_train_inds.shape) / len(self.neg_train_inds)
7980

8081
def get_train_size(self):
8182
return self.pos_train_inds.shape[0] + self.neg_train_inds.shape[0]
@@ -88,7 +89,7 @@ def __getitem__(self, index):
8889
self.pos_train_inds, size=self.batch_size // 2, replace=False, p=self.p_pos
8990
)
9091
selected_neg_inds = np.random.choice(
91-
self.neg_train_inds, size=self.batch_size // 2, replace=False
92+
self.neg_train_inds, size=self.batch_size // 2, replace=False, p = self.p_neg
9293
)
9394
selected_inds = np.concatenate((selected_pos_inds, selected_neg_inds))
9495

0 commit comments

Comments
 (0)