diff --git a/src/ndarray/binary.jl b/src/ndarray/binary.jl index 10974206..959c39bc 100644 --- a/src/ndarray/binary.jl +++ b/src/ndarray/binary.jl @@ -129,11 +129,21 @@ function Base.:(*)(rhs1::NDArray{Bool,2}, rhs2::NDArray{Bool,2}) ) end -function Base.:(*)(rhs1::NDArray{<:Integer,2}, rhs2::NDArray{<:Integer,2}) - #* this is a stupid..... - throw( - ArgumentError("cuNumeric.jl does not support matrix multiplication of two Integer arrays") - ) +function Base.:(*)(rhs1::NDArray{A,2}, rhs2::NDArray{B,2}) where {A<:Integer, B<:Integer} + + size(rhs1, 2) == size(rhs2, 1) || throw(DimensionMismatch("Matrix dimensions incompatible: $(size(rhs1)) × $(size(rhs2))")) + + ResultType = __my_promote_type(A, B) + IntermediateType = Float64 + + A_float = cuNumeric.as_type(rhs1, IntermediateType) + B_float = cuNumeric.as_type(rhs2, IntermediateType) + + C_float = A_float * B_float + C_int = cuNumeric.as_type(C_float, ResultType) + + return C_int + end @doc""" diff --git a/test/tests/gemm.jl b/test/tests/gemm.jl index 57fdc12b..1f85392a 100644 --- a/test/tests/gemm.jl +++ b/test/tests/gemm.jl @@ -36,10 +36,9 @@ function gemm(N, M, T, max_diff) if T <: Integer a = cuNumeric.ones(Int32, 5, 5) a_jl = ones(Int32, 5, 5) - b = cuNumeric.ones(Float32, 5, 5) - b_jl = ones(Float32, 5, 5) - @test_throws ArgumentError a * a - @test @allowscalar cuNumeric.compare(a_jl * b_jl, a * b, 0.0, max_diff) + b = a * a + b_jl = a_jl * a_jl + @test @allowscalar cuNumeric.compare(b_jl, b, 0.0, max_diff) return end