Skip to content

Commit e649cba

Browse files
authored
Merge pull request #753 from JuliaAI/option-to-force-error-in-machine-constructor
Allow user to control the level of scitype checks in `machine` constructor
2 parents 2b5f40a + 1fe75a5 commit e649cba

File tree

13 files changed

+183
-75
lines changed

13 files changed

+183
-75
lines changed

src/MLJBase.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ using Statistics, LinearAlgebra, Random, InteractiveUtils
9494
# ===================================================================
9595
## CONSTANTS
9696

97+
# for variable global constants, see src/init.jl
98+
9799
const PREDICT_OPERATIONS = (:predict,
98100
:predict_mean,
99101
:predict_mode,
@@ -181,7 +183,7 @@ const LOSS_FUNCTIONS = vcat(MARGIN_LOSSES, DISTANCE_LOSSES)
181183
# default_resource allows to switch the mode of parallelization
182184

183185
default_resource() = DEFAULT_RESOURCE[]
184-
default_resource(res) = (DEFAULT_RESOURCE[] = res)
186+
default_resource(res) = (DEFAULT_RESOURCE[] = res;)
185187

186188
# ===================================================================
187189
# Includes
@@ -297,7 +299,7 @@ export flat_values, recursive_setproperty!,
297299
recursive_getproperty, pretty, unwind
298300

299301
# show.jl
300-
export HANDLE_GIVEN_ID, @more, @constant, @bind, color_on, color_off
302+
export HANDLE_GIVEN_ID, @more, @constant, color_on, color_off
301303

302304
# datasets.jl:
303305
export load_boston, load_ames, load_iris, load_sunspots,
@@ -309,7 +311,7 @@ export load_boston, load_ames, load_iris, load_sunspots,
309311
export source, Source, CallableReturning
310312

311313
# machines.jl:
312-
export machine, Machine, fit!, report, fit_only!
314+
export machine, Machine, fit!, report, fit_only!, default_scitype_check_level
313315

314316
# datasets_synthetics.jl
315317
export make_blobs, make_moons, make_circles, make_regression

src/composition/learning_networks/nodes.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,10 +249,10 @@ function _formula(stream, X::Node, depth, indent)
249249
n_args = length(X.args)
250250
if X.machine !== nothing
251251
print(stream, crind(indent + length(operation_name) - anti))
252-
printstyled(IOContext(stream, :color=>SHOW_COLOR),
252+
printstyled(IOContext(stream, :color=>SHOW_COLOR[]),
253253
# handle(X.machine),
254254
X.machine,
255-
bold=SHOW_COLOR)
255+
bold=SHOW_COLOR[])
256256
n_args == 0 || print(stream, ", ")
257257
end
258258
for k in 1:n_args
@@ -277,7 +277,7 @@ function Base.show(io::IO, ::MIME"text/plain", X::Node)
277277
print(io, " formula:\n")
278278
_formula(io, X, 4)
279279
# print(io, " ")
280-
# printstyled(IOContext(io, :color=>SHOW_COLOR),
280+
# printstyled(IOContext(io, :color=>SHOW_COLOR[]),
281281
# handle(X),
282282
# color=color(X))
283283
end

src/init.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
function __init__()
2+
global HANDLE_GIVEN_ID = Dict{UInt64,Symbol}()
23
global DEFAULT_RESOURCE = Ref{AbstractResource}(CPU1())
4+
global DEFAULT_SCITYPE_CHECK_LEVEL = Ref{Int}(1)
5+
global SHOW_COLOR = Ref{Bool}(true)
36

47
# for testing asynchronous training of learning networks:
58
global TESTING = parse(Bool, get(ENV, "TEST_MLJBASE", "false"))

src/machines.jl

Lines changed: 96 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,34 @@
1+
## SCITYPE CHECK LEVEL
2+
3+
"""
4+
default_scitype_check_level()
5+
6+
Return the current global default value for scientific type checking
7+
when constructing machines.
8+
9+
default_scitype_check_level(i::Integer)
10+
11+
Set the global default value for scientific type checking to `i`.
12+
13+
The effect of the `scitype_check_level` option in calls of the form
14+
`machine(model, data, scitype_check_level=...)` is summarized below:
15+
16+
`scitype_check_level` | Inspect scitypes? | If `Unknown` in scitypes | If other scitype mismatch |
17+
|:-------------------:|:-----------------:|:------------------------:|:-------------------------:|
18+
0 | × | | |
19+
1 (value at startup) | ✓ | | warning |
20+
2 | ✓ | warning | warning |
21+
3 | ✓ | warning | error |
22+
4 | ✓ | error | error |
23+
24+
See also [`machine`](@ref)
25+
26+
"""
27+
function default_scitype_check_level end
28+
default_scitype_check_level() = DEFAULT_SCITYPE_CHECK_LEVEL[]
29+
default_scitype_check_level(i) = (DEFAULT_SCITYPE_CHECK_LEVEL[] = i;)
30+
31+
132
## MACHINE TYPE
233

334
struct NotTrainedError{M} <: Exception
@@ -87,42 +118,55 @@ function _contains_unknown(F::Type{<:Tuple})
87118
return any(_contains_unknown, F.parameters)
88119
end
89120

90-
warn_generic_scitype_mismatch(S, F, T) =
121+
alert_generic_scitype_mismatch(S, F, T) =
91122
"The number and/or types of data arguments do not " *
92-
"match what the specified model supports.\n"*
93-
"Run `@doc $T` to learn more about your model's requirements.\n\n"*
123+
"match what the specified model supports. Suppress this "*
124+
"type check by specifying `scitype_check_level=0`.\n\n"*
125+
"Run `@doc $T` to learn more about your model's requirements.\n"*
94126
"Commonly, but non exclusively, supervised models are constructed " *
95127
"using the syntax `machine(model, X, y)` or `machine(model, X, y, w)` " *
96-
"while most other models with `machine(model, X)`. " *
97-
"Here `X` are features, `y` a target, and `w` sample or class weights.\n" *
98-
"In general, data in `machine(model, data...)` must satisfy " *
99-
"`scitype(data) <: MLJ.fit_data_scitype(model)` unless the " *
100-
"right-hand side contains `Unknown` scitypes.\n"*
128+
"while most other models are constructed with `machine(model, X)`. " *
129+
"Here `X` are features, `y` a target, and `w` sample or class weights.\n\n" *
130+
"In general, data in `machine(model, data...)` is expected to satisfy " *
131+
"`scitype(data) <: MLJ.fit_data_scitype(model)`.\n"*
101132
"In the present case:\n"*
102133
"scitype(data) = $S\n"*
103134
"fit_data_scitype(model) = $F\n"
104135

136+
const WARN_UNKNOWN_SCITYPE =
137+
"Some data contains `Unknown` scitypes, which might lead to model-data mismatches. "
138+
105139
err_length_mismatch(model) = DimensionMismatch(
106140
"Differing number of observations "*
107141
"in input and target. ")
108142

109-
check(model::Any, args...; kwargs...) =
143+
check(model::Any, args...) =
110144
throw(ArgumentError("Expected a `Model` instance, got $model. "))
111-
function check(model::Model, args...; full=false)
112-
nowarns = true
145+
function check(model::Model, scitype_check_level, args...)
146+
147+
is_okay = true
148+
149+
scitype_check_level >= 1 || return is_okay
113150

114151
F = fit_data_scitype(model)
115152

116-
# skip checks if `Unknown` scitypes appear anywhere in
117-
# `fit_data_scitype(model)`:
118-
_contains_unknown(F) && return true
153+
if _contains_unknown(F)
154+
scitype_check_level in [2, 3] && @warn WARN_UNKNOWN_SCITYPE
155+
scitype_check_level >= 4 && throw(ArgumentError(WARN_UNKNOWN_SCITYPE))
156+
return is_okay
157+
end
119158

120159
# we use `elscitype` here instead of `scitype` because the data is
121160
# wrapped in source nodes:
122161
S = Tuple{elscitype.(args)...}
123162
if !(S <: F)
124-
@warn warn_generic_scitype_mismatch(S, F, typeof(model))
125-
nowarns = false
163+
is_okay = false
164+
message = alert_generic_scitype_mismatch(S, F, typeof(model))
165+
if scitype_check_level >= 3
166+
throw(ArgumentError(message))
167+
else
168+
@warn message
169+
end
126170
end
127171

128172
if length(args) > 1 && is_supervised(model)
@@ -132,18 +176,19 @@ function check(model::Model, args...; full=false)
132176
scitype(X) == CallableReturning{Nothing} || nrows(X()) == nrows(y()) ||
133177
throw(err_length_mismatch(model))
134178
end
135-
return nowarns
179+
return is_okay
136180
end
137181

138182
"""
139-
machine(model, args...; cache=true)
183+
machine(model, args...; cache=true, scitype_check_level=1)
140184
141185
Construct a `Machine` object binding a `model`, storing
142186
hyper-parameters of some machine learning algorithm, to some data,
143-
`args`. Calling `fit!` on a `Machine` object stores in the machine
144-
object the outcomes of applying the algorithm. This in turn enables
145-
generalization to new data using operations such as `predict` or
146-
`transform`:
187+
`args`. Calling [`fit!`](@ref) on a `Machine` instance `mach` stores
188+
outcomes of applying the algorithm in `mach`, which can be inspected
189+
using `fitted_params(mach)` (learned parameters) and `report(mach)`
190+
(other outcomes). This in turn enables generalization to new data
191+
using operations such as `predict` or `transform`:
147192
148193
```julia
149194
using MLJModels
@@ -161,12 +206,24 @@ mach = machine(model, X, y)
161206
fit!(mach, rows=1:50)
162207
predict(mach, selectrows(X, 51:100)) # or predict(mach, rows=51:100)
163208
```
164-
165-
Specify `cache=false` to prioritize memory management over speed, and
166-
to guarantee data anonymity when serializing composite models.
209+
Specify `cache=false` to prioritize memory management over speed.
167210
168211
When building a learning network, `Node` objects can be substituted
169-
for the concrete data.
212+
for the concrete data but no type or dimension checks are applied.
213+
214+
### Checks on the types of training data
215+
216+
A model articulates its data requirements using [scientific
217+
types](https://juliaai.github.io/ScientificTypes.jl/dev/), i.e.,
218+
using the [`scitype`](@ref) function instead of the `typeof` function.
219+
220+
If `scitype_check_level > 0` then the scitype of each `arg` in `args`
221+
is computed, and this is compared with the scitypes expected by the
222+
model, unless `args` contains `Unknown` scitypes and
223+
`scitype_check_level < 4`, in which case no further action is
224+
taken. Whether warnings are issued or errors thrown depends on the
225+
level. For details, see [`default_scitype_check_level`](@ref), a method
226+
to inspect or change the default level (`1` at startup).
170227
171228
### Learning network machines
172229
@@ -274,7 +331,8 @@ r = report(network_mach)
274331
@assert r.accuracy == accuracy(yhat(), ys())
275332
```
276333
277-
See also [MLJBase.save](@ref), [`serializable`](@ref).
334+
See also [`fit!`](@ref), [`default_scitype_check_level`](@ref),
335+
[`MLJBase.save`](@ref), [`serializable`](@ref).
278336
279337
"""
280338
function machine end
@@ -307,9 +365,13 @@ machine(model::Model, arg1::AbstractNode, arg2, args...; kwargs...) =
307365
error("Mixing concrete data with `Node` training arguments "*
308366
"is not allowed. ")
309367

310-
function machine(model::Model, raw_arg1, raw_args...; kwargs...)
368+
function machine(model::Model,
369+
raw_arg1,
370+
raw_args...;
371+
scitype_check_level=default_scitype_check_level(),
372+
kwargs...)
311373
args = source.((raw_arg1, raw_args...))
312-
check(model, args...; full=true)
374+
check(model, scitype_check_level, args...;)
313375
return Machine(model, args...; kwargs...)
314376
end
315377

@@ -560,7 +622,8 @@ function fit_only!(mach::Machine{<:Model,cache_data};
560622
@warn "Some learning network source nodes are empty. "
561623
@info "Running type checks... "
562624
raw_args = map(N -> N(), mach.args)
563-
if check(mach.model, source.(raw_args)... ; full=true)
625+
scitype_check_level = 1
626+
if check(mach.model, scitype_check_level, source.(raw_args)...)
564627
@info "Type checks okay. "
565628
else
566629
@info "It seems an upstream node in a learning "*
@@ -772,8 +835,9 @@ all training data is removed and, if necessary, learned parameters are replaced
772835
with persistent representations.
773836
774837
Any general purpose Julia serializer may be applied to the output of
775-
`serializable` (eg, JLSO, BSON, JLD) but you must call `restore!(mach)` on
776-
the deserialised object `mach` before using it. See the example below.
838+
`serializable` (eg, JLSO, BSON, JLD) but you must call
839+
`restore!(mach)` on the deserialised object `mach` before using
840+
it. See the example below.
777841
778842
If using Julia's standard Serialization library, a shorter workflow is
779843
available using the [`save`](@ref) method.

src/measures/measure_search.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function Base.show(stream::IO, p::MeasureProxy)
1515
end
1616

1717
function Base.show(stream::IO, ::MIME"text/plain", p::MeasureProxy)
18-
printstyled(IOContext(stream, :color=> MLJBase.SHOW_COLOR),
18+
printstyled(IOContext(stream, :color=> MLJBase.SHOW_COLOR[]),
1919
p.docstring, bold=false, color=:magenta)
2020
println(stream)
2121
MLJBase.fancy_nt(stream, p)

src/resampling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ function Base.show(io::IO, ::MIME"text/plain", e::PerformanceEvaluation)
547547
" per_observation, fitted_params_per_fold,\n"*
548548
" report_per_fold, train_test_rows")
549549
println(io, "Extract:")
550-
show_color = MLJBase.SHOW_COLOR
550+
show_color = MLJBase.SHOW_COLOR[]
551551
color_off()
552552
PrettyTables.pretty_table(io, data, header;
553553
header_crayon=PrettyTables.Crayon(bold=false),

src/show.jl

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,19 @@
11
## REGISTERING LABELS OF OBJECTS DURING ASSIGNMENT
22

3-
const HANDLE_GIVEN_ID = Dict{UInt64,Symbol}()
4-
SHOW_COLOR = true
53
"""
64
color_on()
75
86
Enable color and bold output at the REPL, for enhanced display of MLJ objects.
97
108
"""
11-
color_on() = (global SHOW_COLOR=true;)
9+
color_on() = (SHOW_COLOR[] = true;)
1210
"""
1311
color_off()
1412
1513
Suppress color and bold output at the REPL for displaying MLJ objects.
1614
1715
"""
18-
color_off() = (global SHOW_COLOR=false;)
16+
color_off() = (SHOW_COLOR[] = false;)
1917

2018

2119
macro colon(p)
@@ -25,6 +23,8 @@ end
2523
"""
2624
@constant x = value
2725
26+
Private method (used in testing).
27+
2828
Equivalent to `const x = value` but registers the binding thus:
2929
3030
MLJBase.HANDLE_GIVEN_ID[objectid(value)] = :x
@@ -47,17 +47,6 @@ macro constant(ex)
4747
$(esc(handle))
4848
end
4949
end
50-
macro bind(ex)
51-
ex.head == :(=) || throw(error("Expression must be an assignment."))
52-
handle = ex.args[1]
53-
value = ex.args[2]
54-
quote
55-
$(esc(handle)) = $(esc(value))
56-
id = objectid($(esc(handle)))
57-
HANDLE_GIVEN_ID[id] = @colon $handle
58-
$(esc(handle))
59-
end
60-
end
6150

6251
"""to display abbreviated versions of integers"""
6352
function abbreviated(n)
@@ -154,7 +143,7 @@ function Base.show(stream::IO, object::MLJType)
154143
end
155144
show_handle(object) && (str *= " $(handle(object))")
156145
if false # !isempty(propertynames(object))
157-
printstyled(IOContext(stream, :color=> SHOW_COLOR),
146+
printstyled(IOContext(stream, :color=> SHOW_COLOR[]),
158147
str, bold=false, color=:blue)
159148
else
160149
print(stream, str)
@@ -209,7 +198,7 @@ function fancy(stream, object::MLJType, current_depth, depth, n)
209198
print(stream, ")")
210199
if current_depth == 0 && show_handle(object)
211200
description = " $(handle(object))"
212-
printstyled(IOContext(stream, :color=> SHOW_COLOR),
201+
printstyled(IOContext(stream, :color=> SHOW_COLOR[]),
213202
description, bold=false, color=:blue)
214203
end
215204
end

src/sources.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ function Base.show(stream::IO, object::AbstractNode)
129129
str = simple_repr(typeof(object))
130130
show_handle(object) && (str *= " $(handle(object))")
131131
if false
132-
printstyled(IOContext(stream, :color=> SHOW_COLOR),
132+
printstyled(IOContext(stream, :color=> SHOW_COLOR[]),
133133
str, bold=false, color=:blue)
134134
else
135135
print(stream, str)

src/utilities.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ function pretty(io::IO, X; showtypes=true, alignment=:l, kwargs...)
229229
else
230230
header = (names, )
231231
end
232-
show_color = MLJBase.SHOW_COLOR
232+
show_color = MLJBase.SHOW_COLOR[]
233233
color_off()
234234
try
235235
PrettyTables.pretty_table(io, MLJBase.matrix(X),

0 commit comments

Comments
 (0)