generic_matmul! hit in back! because type-promotion in activation function #613

@oxinabox

Description

Sometimes generic_matmul! is hit in back!
For example, adding a leak to a ReLU6 unit can be done
by writing an activation function like

    leaky_relu6(x) = 0.01x + clamp(x, 0, 6)

And this is all well and good if x is a Float64.
But if x is a Float32, this will trigger a type promotion,
because the constant 0.01 is a Float64 literal.
This is bad: the user almost certainly did not intend the promotion.
But worse, it means that rather than hitting fast BLAS,
we fall back to the slow generic_matmul!.
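
The promotion is easy to see at the REPL (a quick illustration of the mechanism; the 1f0 value here is just an example input, not from the original report):

    julia> x = 1f0;

    julia> typeof(0.01x)    # the Float64 literal 0.01 forces promotion
    Float64

    julia> typeof(0.01f0x)  # a Float32 literal keeps the type
    Float32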

Here is an MWE:

    using Flux

    function flux_model()
        return Chain(
            # Dense(1280, 64, x->0.1f0x),   # Uncomment one of these two lines
            # Dense(1280, 64, x->0.1e0x),   # Uncomment one of these two lines
            Dense(64, 1),
        )
    end

    function demo_flux()
        mdl = flux_model()
        features = rand(Float32, (1280, 1000))

        Flux.train!(
            params(mdl),
            [(features,)],
            Flux.ADAM()
        ) do xs
            sum(mdl(xs))
        end
    end
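
One way to confirm the promotion without timing anything (my own check, assuming Flux's default Float32 weight initialisation) is to inspect the element type of a layer's output:

    julia> d = Dense(1280, 64, x->0.1e0x);

    julia> eltype(d(rand(Float32, 1280)))  # Float32 went in, Float64 came out
    Float64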

Time if it has to promote (the x->0.1e0x line uncommented):

    julia> @time demo_flux()
      0.143774 seconds (607 allocations: 19.635 MiB)

Time normally (the x->0.1f0x line uncommented):

    julia> @time demo_flux()
      0.016475 seconds (568 allocations: 13.218 MiB, 47.67% gc time)

That is roughly a 10x time difference, and the gap grows as the matrix sizes grow.
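
A workaround at the user level (a sketch using the oftype idiom, similar to what NNlib's leakyrelu does; not a fix inside Flux itself) is to write the constant so that it adopts the floating-point type of x:

    # Convert 0.01 to the floating-point type of x, so Float32 stays Float32.
    leaky_relu6(x) = oftype(x / 1, 0.01) * x + clamp(x, 0, 6)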
