
Commit 4d67a5e

Merge pull request #26 from keroro824/tharun
fixed bugs in sampled softmax code
2 parents 04cba83 + 5a14128 commit 4d67a5e

3 files changed: +53 −17 lines

python_examples/config.py

Lines changed: 9 additions & 8 deletions
@@ -3,21 +3,22 @@
 class config:
     data_path_train = '../dataset/Amazon/amazon_train.txt'
     data_path_test = '../dataset/Amazon/amazon_test.txt'
-    data_path = '../dataset/Amazon/amazon_train.txt'
-    GPUs = '' # empty string uses only CPU
-    num_threads = 96 # Only used when GPUs is empty string
+    GPUs = '0' # empty string uses only CPU
+    num_threads = 44 # Only used when GPUs is empty string
     lr = 0.0001
     ###
     feature_dim = 135909
     n_classes = 670091
     n_train = 490449
     n_test = 153025
-    n_epochs = 20
+    n_epochs = 2
     batch_size = 128
     hidden_dim = 128
     ###
-    log_file = 'log'
-
+    log_file = 'log_amz_ss'
     ### for sampled softmax
-    n_samples = 670091//10
-    max_label = 100
+    n_samples = n_classes//10
+    ### choose the max_labels per training sample.
+    ### If the number of true labels is < max_label,
+    ### we will pad the rest of them with a dummy class (see data_generator_ss in util.py)
+    max_label = 1
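
For reference, the padding rule described in the new config comments can be illustrated with a small standalone sketch (made-up label lists; the dummy class id is n_classes = 670091 from this config, and the real logic lives in data_generator_ss in util.py below):

    # Minimal sketch of the max_label padding/truncation rule (illustration only, not part of the commit).
    import numpy as np

    n_classes = 670091   # the dummy/padding class id is n_classes itself
    max_label = 3        # hypothetical value for illustration; the commit sets max_label = 1

    def pad_or_sample(true_labels):
        if max_label >= len(true_labels):
            # pad with the dummy class until the row has exactly max_label entries
            return true_labels + [n_classes] * (max_label - len(true_labels))
        # otherwise keep a random subset of max_label true labels
        return list(np.random.choice(true_labels, max_label, replace=False))

    print(pad_or_sample([12, 7]))        # -> [12, 7, 670091]
    print(pad_or_sample([5, 9, 2, 44]))  # -> e.g. [9, 44, 2] (random subset of size 3)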

python_examples/example_sampled_softmax.py

Lines changed: 13 additions & 8 deletions
@@ -8,7 +8,7 @@
 from config import config
 from itertools import islice
 #from scipy.sparse import csr_matrix
-from util import data_generator, data_generator_tst
+from util import data_generator_ss, data_generator_tst

 ## Training Params
 def main():
@@ -36,23 +36,28 @@ def main():
     x_idxs = tf.placeholder(tf.int64, shape=[None,2])
     x_vals = tf.placeholder(tf.float32, shape=[None])
     x = tf.SparseTensor(x_idxs, x_vals, [batch_size,feature_dim])
-    y = tf.placeholder(tf.float32, shape=[None,n_classes])
+    y = tf.placeholder(tf.float32, shape=[None,max_label])
     #
     W1 = tf.Variable(tf.truncated_normal([feature_dim,hidden_dim], stddev=2.0/math.sqrt(feature_dim+hidden_dim)))
     b1 = tf.Variable(tf.truncated_normal([hidden_dim], stddev=2.0/math.sqrt(feature_dim+hidden_dim)))
     layer_1 = tf.nn.relu(tf.sparse_tensor_dense_matmul(x,W1)+b1)
     #
-    W2 = tf.Variable(tf.truncated_normal([hidden_dim,n_classes], stddev=2.0/math.sqrt(hidden_dim+n_classes)))
-    b2 = tf.Variable(tf.truncated_normal([n_classes], stddev=2.0/math.sqrt(n_classes+hidden_dim)))
-    logits = tf.matmul(layer_1,W2)+b2
+    if max_label>1: # an extra node for padding a dummy class
+        W2 = tf.Variable(tf.truncated_normal([hidden_dim,n_classes+1], stddev=2.0/math.sqrt(hidden_dim+n_classes)))
+        b2 = tf.Variable(tf.truncated_normal([n_classes+1], stddev=2.0/math.sqrt(n_classes+hidden_dim)))
+        logits = tf.matmul(layer_1,W2[:,:-1])+b2[:-1]
+    else:
+        W2 = tf.Variable(tf.truncated_normal([hidden_dim,n_classes], stddev=2.0/math.sqrt(hidden_dim+n_classes)))
+        b2 = tf.Variable(tf.truncated_normal([n_classes], stddev=2.0/math.sqrt(n_classes+hidden_dim)))
+        logits = tf.matmul(layer_1,W2)+b2
     #
     k=1
     if k==1:
         top_idxs = tf.argmax(logits, axis=1)
     else:
         top_idxs = tf.nn.top_k(logits, k=k, sorted=False)[1]
     #
-    loss = tf.reduce_mean(tf.nn.sampled_softmax_loss(tf.transpose(W2),b2,tf.reshape(y,[-1,max_label]),layer_1,n_samples,n_classes,remove_accidental_hits=False, num_true=max_label,partition_strategy='div'))
+    loss = tf.reduce_mean(tf.nn.sampled_softmax_loss(tf.transpose(W2), b2, y, layer_1, n_samples, n_classes, remove_accidental_hits=False, num_true=max_label, partition_strategy='div'))
     #
     train_step = tf.train.AdamOptimizer(lr).minimize(loss)
     #
@@ -65,7 +70,7 @@ def main():
     sess = tf.Session(config=Config)
     sess.run(tf.global_variables_initializer())
     #
-    training_data_generator = data_generator(train_files, batch_size, n_classes)
+    training_data_generator = data_generator_ss(train_files, batch_size, n_classes, max_label)
     steps_per_epoch = n_train//batch_size
     n_steps = n_epochs*steps_per_epoch
     n_check = 500
@@ -94,7 +99,7 @@ def main():
         sess.run(train_step, feed_dict={x_idxs:idxs_batch, x_vals:vals_batch, y:labels_batch})
         if i%steps_per_epoch==steps_per_epoch-1:
             total_time+=time.time()-begin_time
-            print('Finished ',i,' steps. Time elapsed for last 100 batches = ',time.time()-begin_time)
+            print('Finished ',i,' steps. Time elapsed for last', i%n_check, 'batches = ',time.time()-begin_time)
             n_steps_val = n_test//batch_size
             test_data_generator = data_generator_tst(test_files, batch_size)
             num_batches = 0
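
The shape contract behind these changes is the one tf.nn.sampled_softmax_loss imposes: weights must be [num_classes, dim] and biases [num_classes], which is why W2 (built as [hidden_dim, n_classes], or [hidden_dim, n_classes+1] with the dummy padding class) is passed transposed, and labels must be [batch_size, num_true] class ids rather than one-hot rows, which is why the y placeholder shrinks from [None, n_classes] to [None, max_label]. A minimal toy-dimension sketch of that call pattern (TF 1.x API, as used in this example; dimensions are made up):

    # Toy-dimension sketch of the sampled-softmax call pattern (not part of the commit).
    import math
    import tensorflow as tf

    hidden_dim, n_classes, n_samples, max_label = 8, 100, 10, 1

    h = tf.placeholder(tf.float32, shape=[None, hidden_dim])   # hidden-layer activations
    y = tf.placeholder(tf.int64, shape=[None, max_label])      # true class ids, not one-hot
    W2 = tf.Variable(tf.truncated_normal([hidden_dim, n_classes],
                                         stddev=2.0/math.sqrt(hidden_dim+n_classes)))
    b2 = tf.Variable(tf.zeros([n_classes]))

    # weights argument must be [n_classes, hidden_dim], hence the transpose
    loss = tf.reduce_mean(tf.nn.sampled_softmax_loss(
        tf.transpose(W2), b2, y, h,
        num_sampled=n_samples, num_classes=n_classes,
        num_true=max_label, remove_accidental_hits=False,
        partition_strategy='div'))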

python_examples/util.py

Lines changed: 31 additions & 1 deletion
@@ -1,6 +1,6 @@
 from itertools import islice
 import numpy as np
-
+from config import config

 def data_generator(files, batch_size, n_classes):
     while 1:
@@ -34,6 +34,36 @@ def data_generator(files, batch_size, n_classes):
                     lines = []
                     yield (idxs, vals, y_batch)

+def data_generator_ss(files, batch_size, n_classes, max_label):
+    while 1:
+        lines = []
+        for file in files:
+            with open(file,'r',encoding='utf-8') as f:
+                header = f.readline() # ignore the header
+                while True:
+                    temp = len(lines)
+                    lines += list(islice(f,batch_size-temp))
+                    if len(lines)!=batch_size:
+                        break
+                    idxs = []
+                    vals = []
+                    ##
+                    y_batch = [None for i in range(len(lines))]
+                    count = 0
+                    for line in lines:
+                        itms = line.strip().split(' ')
+                        ##
+                        y_batch[count] = [int(itm) for itm in itms[0].split(',')]
+                        if max_label>=len(y_batch[count]): #
+                            y_batch[count] += [n_classes for i in range(max_label-len(y_batch[count]))]
+                        else:
+                            y_batch[count] = np.random.choice(y_batch[count], max_label, replace=False)
+                        ##
+                        idxs += [(count,int(itm.split(':')[0])) for itm in itms[1:]]
+                        vals += [float(itm.split(':')[1]) for itm in itms[1:]]
+                        count += 1
+                    lines = []
+                    yield (idxs, vals, y_batch)

 def data_generator_tst(files, batch_size):
     while 1:
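
As the parsing in data_generator_ss shows, each data line carries the comma-separated true labels first, followed by sparse idx:val feature pairs. For illustration, a hypothetical two-line batch would be unpacked roughly as follows (made-up feature indices and values; the header line of the real files is skipped by the generator):

    # Hypothetical batch of two input lines in the Amazon-670K text format:
    #
    #   12,7 3:0.5 17:1.0
    #   5 3:0.25 8:2.0
    #
    # With batch_size = 2, max_label = 1 and n_classes = 670091, data_generator_ss would yield:
    idxs = [(0, 3), (0, 17), (1, 3), (1, 8)]  # (row in batch, feature index) pairs for the SparseTensor
    vals = [0.5, 1.0, 0.25, 2.0]              # the matching feature values
    y_batch = [[12], [5]]                     # row 0 keeps one of its labels (12 or 7) at random;
                                              # a row with fewer than max_label labels is padded with 670091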
