We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 48ccbf1 commit 7ab20d0Copy full SHA for 7ab20d0
examples/glue_benchmark/glue.py
@@ -216,9 +216,12 @@ def connect_to_tpu(tpu_name):
216
def main(_):
217
if FLAGS.tpu_name:
218
strategy = connect_to_tpu(FLAGS.tpu_name)
219
+ policy = keras.mixed_precision.Policy("mixed_bfloat16")
220
else:
- # Use default strategy is not using TPU.
221
+ # Use default strategy if not using TPU.
222
strategy = tf.distribute.get_strategy()
223
+ policy = keras.mixed_precision.Policy("mixed_float16")
224
+ keras.mixed_precision.set_global_policy(policy)
225
226
train_ds, test_ds, val_ds, idx_order = load_data(FLAGS.task_name)
227
# ----- Custom code block starts -----
0 commit comments