Skip to content

Commit 52ba39d

Browse files
committed
✨ Add support for single vector
1 parent d0c67ac commit 52ba39d

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ ScientificTypes = "3.0"
2626
StatsBase = "0.34"
2727
TableOperations = "1.2"
2828
Tables = "1.11"
29-
julia = "1.6"
29+
julia = "1.10"
3030

3131
[extras]
3232
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"

src/generic.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ logic?"
1616
- X: A table where the elements of the categorical features have [scitypes](https://juliaai.github.io/ScientificTypes.jl/dev/)
1717
`Multiclass` or `OrderedFactor`
1818
- features=[]: A list of names of categorical features given as symbols to exclude or include from encoding,
19-
according to the value of `ignore`
19+
according to the value of `ignore`, or a single symbol (which is treated as a vector with one symbol),
2020
or a callable that returns true for features to be included/excluded
2121
- ignore=true: Whether to exclude or include the features given in `features`
2222
- ordered_factor=false: Whether to encode OrderedFactor or ignore them
@@ -40,8 +40,12 @@ function generic_fit(X,
4040
feat_names = Tables.schema(X).names
4141

4242
#2. Modify feat_names based on features
43+
if features isa Symbol
44+
features = [features]
45+
end
46+
4347
if features isa AbstractVector{Symbol}
44-
# Original behavior for vector of symbols
48+
# Original behavior for vector of symbols
4549
feat_names =
4650
(ignore) ? setdiff(feat_names, features) : intersect(feat_names, features)
4751
else

test/generic.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,4 +182,20 @@ end
182182
# Test 3: predicate with ordered_factor=true picks up ordered factors (e.g., :E)
183183
cache3 = dummy_encoder_fit(X, predicate; ignore=false, ordered_factor=true)
184184
@test Set(cache3[:encoded]) == Set([:A, :C, :E])
185+
end
186+
187+
@testset "Single Symbol and list of one symbol equivalence" begin
188+
X = dataset_forms[1]
189+
feat_names = Tables.schema(X).names
190+
191+
# Test 1: Single Symbol
192+
single_symbol = :A
193+
cache1 = dummy_encoder_fit(X, single_symbol; ignore=true, ordered_factor=false)
194+
@test !(:A in cache1[:encoded])
195+
# Test 2: List of one symbol
196+
single_symbol_list = [:A]
197+
cache2 = dummy_encoder_fit(X, single_symbol_list; ignore=true, ordered_factor=false)
198+
@test !(:A in cache2[:encoded])
199+
# Test 3: Both should yield the same result
200+
@test cache1[:encoded] == cache2[:encoded]
185201
end

0 commit comments

Comments
 (0)