Skip to content

Commit 7ab20d0

Browse files
Add mixed precision support for glue script (#608)
1 parent 48ccbf1 commit 7ab20d0

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

examples/glue_benchmark/glue.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,12 @@ def connect_to_tpu(tpu_name):
216216
def main(_):
217217
if FLAGS.tpu_name:
218218
strategy = connect_to_tpu(FLAGS.tpu_name)
219+
policy = keras.mixed_precision.Policy("mixed_bfloat16")
219220
else:
220-
# Use default strategy is not using TPU.
221+
# Use default strategy if not using TPU.
221222
strategy = tf.distribute.get_strategy()
223+
policy = keras.mixed_precision.Policy("mixed_float16")
224+
keras.mixed_precision.set_global_policy(policy)
222225

223226
train_ds, test_ds, val_ds, idx_order = load_data(FLAGS.task_name)
224227
# ----- Custom code block starts -----

0 commit comments

Comments
 (0)