import string
import requests
+ import os

response = requests.get('https://www.gutenberg.org/cache/epub/1497/pg1497.txt')
data = response.text.split('\n')
@@ -168,6 +169,11 @@ def tokenize_prepare_dataset(lines):
# In[ ]:


+ num_epochs = 200
+ # Allow a custom number of epochs via the environment
+ if "ITEX_NUM_EPOCHS" in os.environ:
+     num_epochs = int(os.environ.get('ITEX_NUM_EPOCHS'))
+
neuron_coef = 4
itex_lstm_model = Sequential()
itex_lstm_model.add(Embedding(input_dim=vocab_size, output_dim=seq_length, input_length=seq_length))
@@ -177,7 +183,7 @@ def tokenize_prepare_dataset(lines):
itex_lstm_model.add(Dense(units=vocab_size, activation='softmax'))
itex_lstm_model.summary()
itex_lstm_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
- itex_lstm_model.fit(x, y, batch_size=256, epochs=200)
+ itex_lstm_model.fit(x, y, batch_size=256, epochs=num_epochs)


# ## Compared to LSTM from Keras
@@ -201,6 +207,11 @@ def tokenize_prepare_dataset(lines):
seq_length = x.shape[1]
vocab_size = y.shape[1]

+ num_epochs = 20
+ # Allow a custom number of epochs via the environment
+ if "KERAS_NUM_EPOCHS" in os.environ:
+     num_epochs = int(os.environ.get('KERAS_NUM_EPOCHS'))
+
neuron_coef = 1
keras_lstm_model = Sequential()
keras_lstm_model.add(Embedding(input_dim=vocab_size, output_dim=seq_length, input_length=seq_length))
@@ -210,7 +221,7 @@ def tokenize_prepare_dataset(lines):
keras_lstm_model.add(Dense(units=vocab_size, activation='softmax'))
keras_lstm_model.summary()
keras_lstm_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
- keras_lstm_model.fit(x, y, batch_size=256, epochs=20)
+ keras_lstm_model.fit(x, y, batch_size=256, epochs=num_epochs)


# ## Generating text based on the input
@@ -276,4 +287,3 @@ def generate_text_seq(model, tokenizer, text_seq_length, seed_text, generated_wo
print("[CODE_SAMPLE_COMPLETED_SUCCESFULLY]")
-
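
Below is a minimal sketch of how the new ITEX_NUM_EPOCHS and KERAS_NUM_EPOCHS overrides could be exercised when launching the sample; the launcher code and the script file name "lstm_text_generation.py" are assumptions for illustration, not part of this change.

import os
import subprocess

# Copy the current environment and shorten training for a quick run.
env = dict(os.environ)
env["ITEX_NUM_EPOCHS"] = "5"    # overrides the ITEX LSTM default of 200 epochs
env["KERAS_NUM_EPOCHS"] = "2"   # overrides the Keras LSTM default of 20 epochs

# Hypothetical file name for the converted notebook; substitute the actual sample script.
subprocess.run(["python", "lstm_text_generation.py"], env=env, check=True)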