forked from tensorflow/tensorflow
Open
Description
Issue type: Performance
Have you reproduced the bug with TensorFlow Nightly? No
Source: binary
TensorFlow version: 2.18.1
Custom code: No
OS platform and distribution: Ubuntu 24.04.2 LTS (WSL2)
Mobile device: No response
Python version: 3.12
Bazel version: No response
GCC/compiler version: No response
CUDA/cuDNN version: No response
GPU model and memory: No response
Current behavior?
Calling tf.keras.mixed_precision.set_global_policy('mixed_float16') makes training considerably slower on a Radeon 9070 XT (gfx1201).
Standalone code to reproduce the issue
The following code runs at roughly 210 ms/step without mixed precision; with mixed precision enabled, this increases to 3 s/step:
import numpy as np
import tensorflow as tf
# Uncomment the next line to enable mixed precision and reproduce the slowdown.
# tf.keras.mixed_precision.set_global_policy("mixed_float16")
filters = 64
convs = 16
act = "relu"
kernel_size = 3
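class DepthToSpace(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        # Forward kwargs (name, dtype, ...) to the base Layer.
        super().__init__(**kwargs)
        self.block_size = 2

    def call(self, inputs):
        # Rearrange channel blocks into space: (H, W, 4) -> (2H, 2W, 1).
        return tf.nn.depth_to_space(inputs, self.block_size)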
inputs = tf.keras.Input(shape=(None, None, 1))
conv0 = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')(inputs)
x = conv0
for _ in range(convs):
    x = tf.keras.layers.Conv2D(filters, kernel_size, padding='same', activation=act)(x)
conv1 = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')(x)
mix = tf.keras.layers.Add()([conv1, conv0])
features = tf.keras.layers.Conv2D(4, kernel_size, padding='same')(mix)
outputs = DepthToSpace()(features)
model = tf.keras.Model(inputs, outputs)
model.summary()
def synthetic_generator(batch_size):
    while True:
        ref_batch = np.random.rand(batch_size, 256, 256, 1).astype(np.float32)  # NHWC
        in_batch = tf.image.resize(ref_batch, [128, 128], method='bilinear').numpy()  # NHWC
        yield in_batch, ref_batch
batch_size = 8
steps_per_epoch = 100
train_gen = synthetic_generator(batch_size)
model.compile(optimizer=tf.keras.optimizers.AdamW(learning_rate=2.5e-5), loss=tf.keras.losses.MeanAbsoluteError())
model.fit(train_gen, steps_per_epoch=steps_per_epoch, epochs=1, verbose=1)
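For anyone trying to confirm the numbers independently of the Keras progress bar, a small timing helper along these lines may be useful. It is only a sketch: `ms_per_step` and `slowdown` are hypothetical helper names, not part of TensorFlow, and the helper times any callable rather than anything specific to this model.

```python
import time
from statistics import median


def ms_per_step(step_fn, steps=20, warmup=3):
    """Return the median wall-clock time of step_fn() in milliseconds."""
    # Warm-up calls are excluded so one-time costs (graph tracing, kernel
    # selection, memory allocation) do not distort the steady-state figure.
    for _ in range(warmup):
        step_fn()
    timings = []
    for _ in range(steps):
        start = time.perf_counter()
        step_fn()
        timings.append((time.perf_counter() - start) * 1000.0)
    return median(timings)


def slowdown(baseline_ms, candidate_ms):
    """Return how many times slower candidate_ms is than baseline_ms."""
    return candidate_ms / baseline_ms


# The figures reported above: roughly 210 ms/step vs 3 s/step.
print(f"{slowdown(210.0, 3000.0):.1f}x slower")  # prints "14.3x slower"
```

With the reproduction script, one way to use it would be to draw a single batch with `x, y = next(train_gen)` and pass `lambda: model.train_on_batch(x, y)` as `step_fn`, once with the policy line commented out and once with it enabled. Reusing one batch keeps the generator's resize cost out of the measurement.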