1
+ # # SCITYPE CHECK LEVEL
2
+
3
+ """
4
+ default_scitype_check_level()
5
+
6
+ Return the current global default value for scientific type checking
7
+ when constructing machines.
8
+
9
+ default_scitype_check_level(i::Integer)
10
+
11
+ Set the global default value for scientific type checking to `i`.
12
+
13
+ The effect of the `scitype_check_level` option in calls of the form
14
+ `machine(model, data, scitype_check_level=...)` is summarized below:
15
+
16
+ `scitype_check_level` | Inspect scitypes? | If `Unknown` in scitypes | If other scitype mismatch |
17
+ |:-------------------:|:-----------------:|:------------------------:|:-------------------------:|
18
+ 0 | × | | |
19
+ 1 (value at startup) | ✓ | | warning |
20
+ 2 | ✓ | warning | warning |
21
+ 3 | ✓ | warning | error |
22
+ 4 | ✓ | error | error |
23
+
24
+ See also [`machine`](@ref)
25
+
26
+ """
27
+ function default_scitype_check_level end
28
+ default_scitype_check_level () = DEFAULT_SCITYPE_CHECK_LEVEL[]
29
+ default_scitype_check_level (i) = (DEFAULT_SCITYPE_CHECK_LEVEL[] = i;)
30
+
31
+
1
32
# # MACHINE TYPE
2
33
3
34
struct NotTrainedError{M} <: Exception
@@ -87,42 +118,55 @@ function _contains_unknown(F::Type{<:Tuple})
87
118
return any (_contains_unknown, F. parameters)
88
119
end
89
120
90
- warn_generic_scitype_mismatch (S, F, T) =
121
+ alert_generic_scitype_mismatch (S, F, T) =
91
122
" The number and/or types of data arguments do not " *
92
- " match what the specified model supports.\n " *
93
- " Run `@doc $T ` to learn more about your model's requirements.\n\n " *
123
+ " match what the specified model supports. Suppress this " *
124
+ " type check by specifying `scitype_check_level=0`.\n\n " *
125
+ " Run `@doc $T ` to learn more about your model's requirements.\n " *
94
126
" Commonly, but non exclusively, supervised models are constructed " *
95
127
" using the syntax `machine(model, X, y)` or `machine(model, X, y, w)` " *
96
- " while most other models with `machine(model, X)`. " *
97
- " Here `X` are features, `y` a target, and `w` sample or class weights.\n " *
98
- " In general, data in `machine(model, data...)` must satisfy " *
99
- " `scitype(data) <: MLJ.fit_data_scitype(model)` unless the " *
100
- " right-hand side contains `Unknown` scitypes.\n " *
128
+ " while most other models are constructed with `machine(model, X)`. " *
129
+ " Here `X` are features, `y` a target, and `w` sample or class weights.\n\n " *
130
+ " In general, data in `machine(model, data...)` is expected to satisfy " *
131
+ " `scitype(data) <: MLJ.fit_data_scitype(model)`.\n " *
101
132
" In the present case:\n " *
102
133
" scitype(data) = $S \n " *
103
134
" fit_data_scitype(model) = $F \n "
104
135
136
+ const WARN_UNKNOWN_SCITYPE =
137
+ " Some data contains `Unknown` scitypes, which might lead to model-data mismatches. "
138
+
105
139
err_length_mismatch (model) = DimensionMismatch (
106
140
" Differing number of observations " *
107
141
" in input and target. " )
108
142
109
- check (model:: Any , args... ; kwargs ... ) =
143
+ check (model:: Any , args... ) =
110
144
throw (ArgumentError (" Expected a `Model` instance, got $model . " ))
111
- function check (model:: Model , args... ; full= false )
112
- nowarns = true
145
+ function check (model:: Model , scitype_check_level, args... )
146
+
147
+ is_okay = true
148
+
149
+ scitype_check_level >= 1 || return is_okay
113
150
114
151
F = fit_data_scitype (model)
115
152
116
- # skip checks if `Unknown` scitypes appear anywhere in
117
- # `fit_data_scitype(model)`:
118
- _contains_unknown (F) && return true
153
+ if _contains_unknown (F)
154
+ scitype_check_level in [2 , 3 ] && @warn WARN_UNKNOWN_SCITYPE
155
+ scitype_check_level >= 4 && throw (ArgumentError (WARN_UNKNOWN_SCITYPE))
156
+ return is_okay
157
+ end
119
158
120
159
# we use `elscitype` here instead of `scitype` because the data is
121
160
# wrapped in source nodes:
122
161
S = Tuple{elscitype .(args)... }
123
162
if ! (S <: F )
124
- @warn warn_generic_scitype_mismatch (S, F, typeof (model))
125
- nowarns = false
163
+ is_okay = false
164
+ message = alert_generic_scitype_mismatch (S, F, typeof (model))
165
+ if scitype_check_level >= 3
166
+ throw (ArgumentError (message))
167
+ else
168
+ @warn message
169
+ end
126
170
end
127
171
128
172
if length (args) > 1 && is_supervised (model)
@@ -132,18 +176,19 @@ function check(model::Model, args...; full=false)
132
176
scitype (X) == CallableReturning{Nothing} || nrows (X ()) == nrows (y ()) ||
133
177
throw (err_length_mismatch (model))
134
178
end
135
- return nowarns
179
+ return is_okay
136
180
end
137
181
138
182
"""
139
- machine(model, args...; cache=true)
183
+ machine(model, args...; cache=true, scitype_check_level=1 )
140
184
141
185
Construct a `Machine` object binding a `model`, storing
142
186
hyper-parameters of some machine learning algorithm, to some data,
143
- `args`. Calling `fit!` on a `Machine` object stores in the machine
144
- object the outcomes of applying the algorithm. This in turn enables
145
- generalization to new data using operations such as `predict` or
146
- `transform`:
187
+ `args`. Calling [`fit!`](@ref) on a `Machine` instance `mach` stores
188
+ outcomes of applying the algorithm in `mach`, which can be inspected
189
+ using `fitted_params(mach)` (learned paramters) and `report(mach)`
190
+ (other outcomes). This in turn enables generalization to new data
191
+ using operations such as `predict` or `transform`:
147
192
148
193
```julia
149
194
using MLJModels
@@ -161,12 +206,24 @@ mach = machine(model, X, y)
161
206
fit!(mach, rows=1:50)
162
207
predict(mach, selectrows(X, 51:100)) # or predict(mach, rows=51:100)
163
208
```
164
-
165
- Specify `cache=false` to prioritize memory management over speed, and
166
- to guarantee data anonymity when serializing composite models.
209
+ Specify `cache=false` to prioritize memory management over speed.
167
210
168
211
When building a learning network, `Node` objects can be substituted
169
- for the concrete data.
212
+ for the concrete data but no type or dimension checks are applied.
213
+
214
+ ### Checks on the types of training data
215
+
216
+ A model articulates its data requirements using [scientific
217
+ types](https://juliaai.github.io/ScientificTypes.jl/dev/), i.e.,
218
+ using the [`scitype`](@ref) function instead of the `typeof` function.
219
+
220
+ If `scitype_check_level > 0` then the scitype of each `arg` in `args`
221
+ is computed, and this is compared with the scitypes expected by the
222
+ model, unless `args` contains `Unknown` scitypes and
223
+ `scitype_check_level < 4`, in which case no further action is
224
+ taken. Whether warnings are issued or errors thrown depends the
225
+ level. For details, see `default_scitype_check_level`](@ref), a method
226
+ to inspect or change the default level (`1` at startup).
170
227
171
228
### Learning network machines
172
229
@@ -274,7 +331,8 @@ r = report(network_mach)
274
331
@assert r.accuracy == accuracy(yhat(), ys())
275
332
```
276
333
277
- See also [MLJBase.save](@ref), [`serializable`](@ref).
334
+ See also [`fit!`](@ref), [`default_scitype_check_level`](@ref),
335
+ [`MLJBase.save`](@ref), [`serializable`](@ref).
278
336
279
337
"""
280
338
function machine end
@@ -307,9 +365,13 @@ machine(model::Model, arg1::AbstractNode, arg2, args...; kwargs...) =
307
365
error (" Mixing concrete data with `Node` training arguments " *
308
366
" is not allowed. " )
309
367
310
- function machine (model:: Model , raw_arg1, raw_args... ; kwargs... )
368
+ function machine (model:: Model ,
369
+ raw_arg1,
370
+ raw_args... ;
371
+ scitype_check_level= default_scitype_check_level (),
372
+ kwargs... )
311
373
args = source .((raw_arg1, raw_args... ))
312
- check (model, args... ; full = true )
374
+ check (model, scitype_check_level, args... ;)
313
375
return Machine (model, args... ; kwargs... )
314
376
end
315
377
@@ -560,7 +622,8 @@ function fit_only!(mach::Machine{<:Model,cache_data};
560
622
@warn " Some learning network source nodes are empty. "
561
623
@info " Running type checks... "
562
624
raw_args = map (N -> N (), mach. args)
563
- if check (mach. model, source .(raw_args)... ; full= true )
625
+ scitype_check_level = 1
626
+ if check (mach. model, scitype_check_level, source .(raw_args)... )
564
627
@info " Type checks okay. "
565
628
else
566
629
@info " It seems an upstream node in a learning " *
@@ -772,8 +835,9 @@ all training data is removed and, if necessary, learned parameters are replaced
772
835
with persistent representations.
773
836
774
837
Any general purpose Julia serializer may be applied to the output of
775
- `serializable` (eg, JLSO, BSON, JLD) but you must call `restore!(mach)` on
776
- the deserialised object `mach` before using it. See the example below.
838
+ `serializable` (eg, JLSO, BSON, JLD) but you must call
839
+ `restore!(mach)` on the deserialised object `mach` before using
840
+ it. See the example below.
777
841
778
842
If using Julia's standard Serialization library, a shorter workflow is
779
843
available using the [`save`](@ref) method.
0 commit comments