Skip to content

Commit de74fca

Browse files
authored
Update TextGenerationModelTraining.py
1 parent 47e5d87 commit de74fca

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

AI-and-Analytics/Features-and-Functionality/IntelTensorFlow_TextGeneration_with_LSTM/TextGenerationModelTraining.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
import string
3030
import requests
31+
import os
3132

3233
response = requests.get('https://www.gutenberg.org/cache/epub/1497/pg1497.txt')
3334
data = response.text.split('\n')
@@ -168,6 +169,11 @@ def tokenize_prepare_dataset(lines):
168169
# In[ ]:
169170

170171

172+
num_epochs = 200
173+
# For custom epochs numbers from the environment
174+
if "ITEX_NUM_EPOCHS" in os.environ:
175+
num_epochs = int(os.environ.get('ITEX_NUM_EPOCHS'))
176+
171177
neuron_coef = 4
172178
itex_lstm_model = Sequential()
173179
itex_lstm_model.add(Embedding(input_dim=vocab_size, output_dim=seq_length, input_length=seq_length))
@@ -177,7 +183,7 @@ def tokenize_prepare_dataset(lines):
177183
itex_lstm_model.add(Dense(units=vocab_size, activation='softmax'))
178184
itex_lstm_model.summary()
179185
itex_lstm_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
180-
itex_lstm_model.fit(x,y, batch_size=256, epochs=200)
186+
itex_lstm_model.fit(x,y, batch_size=256, epochs=num_epochs)
181187

182188

183189
# ## Compared to LSTM from Keras
@@ -201,6 +207,11 @@ def tokenize_prepare_dataset(lines):
201207
seq_length = x.shape[1]
202208
vocab_size = y.shape[1]
203209

210+
num_epochs = 20
211+
# For custom epochs numbers
212+
if "KERAS_NUM_EPOCHS" in os.environ:
213+
num_epochs = int(os.environ.get('KERAS_NUM_EPOCHS'))
214+
204215
neuron_coef = 1
205216
keras_lstm_model = Sequential()
206217
keras_lstm_model.add(Embedding(input_dim=vocab_size, output_dim=seq_length, input_length=seq_length))
@@ -210,7 +221,7 @@ def tokenize_prepare_dataset(lines):
210221
keras_lstm_model.add(Dense(units=vocab_size, activation='softmax'))
211222
keras_lstm_model.summary()
212223
keras_lstm_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
213-
keras_lstm_model.fit(x,y, batch_size=256, epochs=20)
224+
keras_lstm_model.fit(x,y, batch_size=256, epochs=num_epochs)
214225

215226

216227
# ## Generating text based on the input
@@ -276,4 +287,3 @@ def generate_text_seq(model, tokenizer, text_seq_length, seed_text, generated_wo
276287

277288

278289
print("[CODE_SAMPLE_COMPLETED_SUCCESFULLY]")
279-

0 commit comments

Comments (0)