
A self-defined AbstractMatrix subtype does not respect a self-defined Base.zero for its elements (a subtype of Number) when performing matrix multiplication #1334

@waylonwh

Description

Consider the Viterbi semiring (max-times semiring), in which the binary operations are defined as:
– addition: max
– multiplication: *
with -Inf as the additive identity and 1 as the multiplicative identity.
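For reference, the product this semiring defines can be sketched on plain Float64 values. The helper `maxtimes_matmul` below is hypothetical (it is not a LinearAlgebra function). It starts every entry at the additive identity `-Inf` and folds the products with `max`, which is the behaviour the example below expects from `A * O`. It assumes ordinary 1-based `Matrix` inputs and does not special-case `0 * -Inf`, which is `NaN` in IEEE arithmetic.

```julia
# Hypothetical helper (not a LinearAlgebra function): matrix product in the
# max-times semiring, C[i, j] = max over k of A[i, k] * B[k, j].
function maxtimes_matmul(A::Matrix{<:Real}, B::Matrix{<:Real})
    size(A, 2) == size(B, 1) || throw(DimensionMismatch("inner dimensions differ"))
    # Every entry starts at the semiring "0", the additive identity -Inf.
    C = fill(-Inf, size(A, 1), size(B, 2))
    for j in 1:size(B, 2), i in 1:size(A, 1), k in 1:size(A, 2)
        # Semiring "+" is max; semiring "*" is ordinary multiplication.
        C[i, j] = max(C[i, j], A[i, k] * B[k, j])
    end
    return C
end

A = [1.0 2.0; 3.0 4.0]
O = fill(-Inf, 2, 2)            # semiring zero matrix
Id = [1.0 (-Inf); (-Inf) 1.0]   # semiring identity matrix

maxtimes_matmul(A, O) == O      # true: every entry is -Inf
maxtimes_matmul(A, Id) == A     # true
```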

A minimal working example is given below:

```julia
# Element type: a number in the Viterbi (max-times) semiring
struct ViterbiSymbol{T<:Number} <: Number
    v::T
end

# Semiring identities: "0" is -Inf, "1" is 1
Base.zero(::Type{ViterbiSymbol}) = ViterbiSymbol{Float64}(-Inf)
Base.one(::Type{ViterbiSymbol{T}}) where {T} = ViterbiSymbol{T}(1)

# Semiring operations: "+" is max, "*" is ordinary multiplication
Base.:+(a::ViterbiSymbol{T}, b::ViterbiSymbol{T}) where {T} = ViterbiSymbol{T}(max(a.v, b.v))
Base.:*(a::ViterbiSymbol{T}, b::ViterbiSymbol{T}) where {T} = ViterbiSymbol{T}(a.v * b.v)

# Allow mixing ViterbiSymbol{Int} and ViterbiSymbol{Float64}
Base.promote_rule(::Type{ViterbiSymbol{Int}}, ::Type{ViterbiSymbol{Float64}}) = ViterbiSymbol{Float64}
Base.convert(::Type{ViterbiSymbol{Float64}}, v::ViterbiSymbol{Int}) = ViterbiSymbol{Float64}(float(v.v))

# Matrix type wrapping a dense matrix of ViterbiSymbols
struct ViterbiMatrix{T<:Number} <: AbstractMatrix{ViterbiSymbol{T}}
    m::Matrix{ViterbiSymbol{T}}
end
# Build the semiring zero matrix (e === zero) or identity matrix (e === one)
function ViterbiMatrix(e::Union{typeof(zero),typeof(one)}, size::Tuple{Int,Int})
    if e === zero
        m::Matrix{ViterbiSymbol{Float64}} = fill(e(ViterbiSymbol), size) # all "0"
    else # e is one, all "0" except diagonal being "1"
        m = fill(zero(ViterbiSymbol), size[2], size[2])
        foreach(i -> m[i, i] = e(ViterbiSymbol), 1:size[2])
    end
    return ViterbiMatrix{Float64}(m)
end

# AbstractArray interface plus elementwise + and matrix *
Base.size(a::ViterbiMatrix) = Base.size(a.m)
Base.getindex(a::ViterbiMatrix, i::Vararg{Int,2}) = Base.getindex(a.m, i[1], i[2])
Base.:+(a::ViterbiMatrix, b::ViterbiMatrix) = ViterbiMatrix(a.m + b.m)
Base.:*(a::ViterbiMatrix, b::ViterbiMatrix) = ViterbiMatrix(a.m * b.m)
# Enables the literal syntax ViterbiSymbol[1 2; 3 4]; xs arrive in row-major order
Base.typed_hvcat(
    ::Type{ViterbiSymbol}, rows::Tuple{Vararg{Int}}, xs::Number...
) = ViterbiMatrix(Matrix{ViterbiSymbol{promote_type(map(typeof, xs)...)}}(
    permutedims(reshape(collect(xs), rows[1], length(rows)))
))

A = ViterbiSymbol[1 2; 3 4];
O = ViterbiMatrix(zero, size(A))

display(A)
display(O)
display(A * O) # expected: all -Inf, got all 0.0
display(A * O == O) # expected: true, got false
```

If my understanding is correct, the function _rmul_or_fill! is called to fill the output matrix C with zeros before the products are accumulated during multiplication. This function calls zero(eltype(C)), but it seems to dispatch to Base's generic fallback zero(::Type{T}) where {T<:Number} (which returns convert(T, 0), i.e. a ViterbiSymbol wrapping 0.0) instead of the custom zero method defined above, which returns ViterbiSymbol{Float64}(-Inf).
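A quick check that may help narrow this down, reusing the struct and `zero` method from the example above (redefining an identical struct in the same session is allowed). Here `eltype(C)` is the concrete type `ViterbiSymbol{Float64}`, and because `Type{...}` is invariant, `ViterbiSymbol{Float64}` is not an instance of `Type{ViterbiSymbol}`. So the custom method is only selected when `zero` is called on the UnionAll `ViterbiSymbol` itself.

```julia
# Same element type and `zero` method as in the example above.
struct ViterbiSymbol{T<:Number} <: Number
    v::T
end
Base.zero(::Type{ViterbiSymbol}) = ViterbiSymbol{Float64}(-Inf)

# `Type{...}` is invariant: the concrete type does not match `::Type{ViterbiSymbol}`.
ViterbiSymbol{Float64} isa Type{ViterbiSymbol}  # false

zero(ViterbiSymbol).v           # -Inf (the custom method)
zero(ViterbiSymbol{Float64}).v  # 0.0 (Base's fallback, convert(T, 0) for T<:Number)
```

If this is what happens, a method on the concrete types, for example one with the signature `Base.zero(::Type{ViterbiSymbol{T}}) where {T<:AbstractFloat}` (mirroring how `one` is defined above), would presumably be needed for `zero(eltype(C))` to return `-Inf`. Whether that alone changes the result of `A * O` depends on which code path LinearAlgebra takes, so treat it as something to try rather than a confirmed fix.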
