Skip to content

Commit 4ab2762

Browse files
authored
Merge pull request #67 from alan-turing-institute/scitype1
Bump [compat] ScientificTypes = "^1"
2 parents 93a309e + 8f55aa3 commit 4ab2762

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
name = "MLJModelInterface"
22
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33
authors = ["Thibaut Lienart and Anthony Blaom"]
4-
version = "0.3.4"
4+
version = "0.3.5"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
88
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
99

1010
[compat]
11-
ScientificTypes = "^0.7,^0.8"
11+
ScientificTypes = "^1"
1212
julia = "1"
1313

1414
[extras]

test/model_api.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,44 @@ mutable struct APIx1 <: Static end
2525
#update fallback = fit
2626
@test update(m0, 1, 5, nothing, randn(2), 5) == (5, nothing, nothing)
2727
end
28+
29+
struct DummyUnivariateFinite end
30+
31+
mutable struct UnivariateFiniteFitter <: Probabilistic end
32+
33+
@testset "models fitting a distribution to data" begin
34+
35+
function MLJModelInterface.fit(model::UnivariateFiniteFitter,
36+
verbosity::Int, X, y)
37+
38+
fitresult = DummyUnivariateFinite()
39+
report = nothing
40+
cache = nothing
41+
42+
verbosity > 0 && @info "Fitted a $fitresult"
43+
44+
return fitresult, cache, report
45+
end
46+
47+
MLJModelInterface.predict(model::UnivariateFiniteFitter,
48+
fitresult,
49+
X) = fill(fitresult, length(X))
50+
51+
MLJModelInterface.input_scitype(::Type{<:UnivariateFiniteFitter}) =
52+
AbstractVector{Nothing}
53+
MLJModelInterface.target_scitype(::Type{<:UnivariateFiniteFitter}) =
54+
AbstractVector{<:Finite}
55+
56+
y =categorical(collect("aabbccaa"))
57+
X = fill(nothing, length(y))
58+
model = UnivariateFiniteFitter()
59+
fitresult, cache, report = MLJModelInterface.fit(model, 1, X, y)
60+
61+
@test cache == nothing
62+
@test report == nothing
63+
64+
ytest = y[1:3]
65+
yhat = predict(model, fitresult, fill(nothing, 3))
66+
@test yhat == fill(DummyUnivariateFinite(), 3)
67+
68+
end

0 commit comments

Comments
 (0)