Skip to content

Commit bab7664

Browse files
committed
✨ Add output type test for contrast encoder
1 parent 5808f8e commit bab7664

File tree

1 file changed

+36
-2
lines changed

1 file changed

+36
-2
lines changed

test/encoders/contrast_encoder.jl

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,7 @@ end
5151
cache = contrast_encoder_fit(X, [:name]; ignore=false, mode = :dummy)
5252
k = length(levels(X.name))
5353
contrast_matrix = get_dummy_contrast(k)
54-
print()
5554
for (i, level) in enumerate(levels(X.name))
56-
println(cache[:vector_given_value_given_feature])
5755
@test cache[:vector_given_value_given_feature][:name][level] == contrast_matrix[i, :]
5856
end
5957
end
@@ -289,4 +287,40 @@ end
289287

290288
# Test report
291289
@test report(mach) == (encoded_features = generic_cache[:encoded_features],)
290+
end
291+
292+
293+
@testset "Test Contrast Encoder Output Types" begin
294+
X = (
295+
name = categorical(["Ben", "John", "Mary", "John"]),
296+
height = [1.85, 1.67, 1.5, 1.67],
297+
favnum = categorical([7, 5, 10, 1]),
298+
age = [23, 23, 14, 23],
299+
)
300+
301+
methods = [:contrast, :dummy, :sum, :backward_diff, :helmert, :hypothesis]
302+
matrix_func = [buildrandomcontrast, nothing, nothing, nothing, nothing, buildrandomhypothesis]
303+
304+
for (i, method) in enumerate(methods)
305+
encoder = ContrastEncoder(
306+
features = [:name, :favnum],
307+
ignore = false,
308+
mode = method,
309+
buildmatrix=matrix_func[i]
310+
)
311+
mach = fit!(machine(encoder, X))
312+
Xnew = MMI.transform(mach, X)
313+
314+
# Test Consistency with Types
315+
scs = schema(Xnew).scitypes
316+
ts = schema(Xnew).types
317+
318+
# Check scitypes for previously continuos or categorical features
319+
@test all(scs[1:end-1] .== Continuous)
320+
@test all(t -> (t <: AbstractFloat) && isconcretetype(t), ts[1:end-1])
321+
# Check scitypes for previously Count feature
322+
last_type, last_sctype = ts[end], scs[end]
323+
@test last_type <: Integer && isconcretetype(last_type)
324+
@test last_sctype <: Count
325+
end
292326
end

0 commit comments

Comments
 (0)