|
| 1 | +module MLJModelInterface |
| 2 | + |
| 3 | +# ------------------------------------------------------------------------ |
| 4 | +# Dependency (note that ScientificTypes itself does not have dependencies) |
| 5 | +import ScientificTypes: trait |
| 6 | + |
| 7 | +# ------------------------------------------------------------------------ |
| 8 | +# Single export: matrix, everything else is qualified in MLJBase |
| 9 | +export matrix |
| 10 | + |
| 11 | +# ------------------------------------------------------------------------ |
| 12 | + |
| 13 | +abstract type Mode end |
| 14 | +struct LightInterface <: Mode end |
| 15 | +struct FullInterface <: Mode end |
| 16 | + |
| 17 | +const INTERFACE_MODE = Ref{Mode}(LightInterface()) |
| 18 | + |
| 19 | +set_interface_mode(m::Mode) = (INTERFACE_MODE[] = m) |
| 20 | + |
| 21 | +get_interface_mode() = INTERFACE_MODE[] |
| 22 | + |
| 23 | +struct InterfaceError <: Exception |
| 24 | + m::String |
| 25 | +end |
| 26 | + |
| 27 | +vtrait(X) = X |> trait |> Val |
| 28 | + |
| 29 | +""" |
| 30 | + matrix(X; transpose=false) |
| 31 | +
|
| 32 | +If `X <: AbstractMatrix`, return `X` or `permutedims(X)` if `transpose=true`. |
| 33 | +If `X` is a Tables.jl compatible table source, convert `X` into a `Matrix`. |
| 34 | +""" |
| 35 | +matrix(X; kw...) = matrix(vtrait(X), X, get_interface_mode(); kw...) |
| 36 | + |
| 37 | +matrix(::Val{:other}, X::AbstractMatrix, ::Mode; transpose=false) = |
| 38 | + transpose ? permutedims(X) : X |
| 39 | + |
| 40 | +matrix(::Val{:other}, X, ::Mode; kw...) = |
| 41 | + throw(ArgumentError("Function `matrix` only supports AbstractMatrix or " * |
| 42 | + "containers implementing the Tables interface.")) |
| 43 | + |
| 44 | +matrix(::Val{:table}, X, ::LightInterface; kw...) = |
| 45 | + throw(InterfaceError("Only `MLJModelInterface` loaded. Import `MLJBase`.")) |
| 46 | + |
| 47 | +end # module |
0 commit comments