Mixed precision results in lower performance on RDNA4 #3054

@Artoriuz

Description

Issue type

Performance

Have you reproduced the bug with TensorFlow Nightly?

No

Source

binary

TensorFlow version

2.18.1

Custom code

No

OS platform and distribution

Ubuntu 24.04.2 LTS (WSL2)

Mobile device

No response

Python version

3.12

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current behavior?

Setting tf.keras.mixed_precision.set_global_policy('mixed_float16') results in substantially lower training throughput on a Radeon 9070 XT (gfx1201).
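
For reference, a minimal sketch of what the policy changes (standard tf.keras.mixed_precision API): layer computations run in float16 while the variables stay in float32.

import tensorflow as tf

# mixed_float16: float16 compute dtype, float32 variable dtype.
tf.keras.mixed_precision.set_global_policy("mixed_float16")
policy = tf.keras.mixed_precision.global_policy()
print(policy.compute_dtype)   # float16
print(policy.variable_dtype)  # float32

# Layers created after the policy is set inherit it automatically.
layer = tf.keras.layers.Conv2D(64, 3, padding="same")
print(layer.compute_dtype, layer.variable_dtype)  # float16 float32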

Standalone code to reproduce the issue

The following code runs at roughly 210 ms/step without mixed precision, but step time increases to about 3 s/step with the policy enabled (uncomment the set_global_policy line to reproduce the slow case):

import numpy as np
import tensorflow as tf

# tf.keras.mixed_precision.set_global_policy("mixed_float16")

filters = 64
convs = 16
act = "relu"
kernel_size = 3

class DepthToSpace(tf.keras.layers.Layer):
    """Pixel shuffle: rearranges channel blocks into a 2x larger spatial grid."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.block_size = 2

    def call(self, inputs):
        return tf.nn.depth_to_space(inputs, self.block_size)

inputs = tf.keras.Input(shape=(None, None, 1))
conv0 = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')(inputs)

x = conv0
for _ in range(convs):
    x = tf.keras.layers.Conv2D(filters, kernel_size, padding='same', activation=act)(x)

conv1 = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')(x)
mix = tf.keras.layers.Add()([conv1, conv0])
features = tf.keras.layers.Conv2D(4, kernel_size, padding='same')(mix)
outputs = DepthToSpace()(features)

model = tf.keras.Model(inputs, outputs)
model.summary()

def synthetic_generator(batch_size):
    while True:
        ref_batch = np.random.rand(batch_size, 256, 256, 1).astype(np.float32) # NHWC
        in_batch = tf.image.resize(ref_batch, [128, 128], method='bilinear').numpy() # NHWC
        yield in_batch, ref_batch

batch_size = 8
steps_per_epoch = 100
train_gen = synthetic_generator(batch_size)

model.compile(optimizer=tf.keras.optimizers.AdamW(learning_rate=2.5e-5), loss=tf.keras.losses.MeanAbsoluteError())
model.fit(train_gen, steps_per_epoch=steps_per_epoch, epochs=1, verbose=1)
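
For what it's worth, the Keras mixed-precision guide recommends keeping the model outputs in float32. A variant of the model tail above with that change (purely illustrative; not verified to affect the timings on this GPU):

# Hypothetical variant: force the last layers to float32 so the model output
# is not float16. Layer names mirror the repro script above.
features = tf.keras.layers.Conv2D(4, kernel_size, padding='same', dtype='float32')(mix)
outputs = DepthToSpace(dtype='float32')(features)
model = tf.keras.Model(inputs, outputs)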

Relevant log output
