Skip to content

Commit 8cd4087

Browse files
authored
Improve performance (#67)
1 parent fe4393e commit 8cd4087

File tree

1 file changed

+2
-2
lines changed
  • server/embedding_as_service/text/xlnet/models

1 file changed

+2
-2
lines changed

server/embedding_as_service/text/xlnet/models/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,10 @@ def get_cache_fn(mem_len):
204204
def cache_fn(batch_size):
205205
mems = []
206206
if FLAGS.mem_len > 0:
207-
for _ in range(FLAGS.n_layer):
208-
zeros = tf.zeros(
207+
zeros = tf.zeros(
209208
[mem_len, batch_size, FLAGS.d_model],
210209
dtype=tf_float)
210+
for _ in range(FLAGS.n_layer):
211211
mems.append(zeros)
212212

213213
return mems

0 commit comments

Comments
 (0)