diff --git a/Project.toml b/Project.toml index 90e50c4..075e345 100644 --- a/Project.toml +++ b/Project.toml @@ -10,15 +10,22 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +[weakdeps] +CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" + +[extensions] +OneHotArraysCategoricalArraysExt = "CategoricalArrays" + [compat] Adapt = "3.0, 4" CUDA = "4, 5" +CategoricalArrays = "0.10.8" ChainRulesCore = "1.13" Compat = "4.2" GPUArraysCore = "0.1, 0.2" NNlib = "0.8, 0.9" Zygote = "0.6.35" -julia = "1.6" +julia = "1.10" [extras] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" @@ -28,4 +35,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "CUDA", "JLArrays", "Random", "Zygote"] +test = ["Test", "CategoricalArrays", "CUDA", "JLArrays", "Random", "Zygote"] diff --git a/ext/OneHotArraysCategoricalArraysExt.jl b/ext/OneHotArraysCategoricalArraysExt.jl new file mode 100644 index 0000000..beb14be --- /dev/null +++ b/ext/OneHotArraysCategoricalArraysExt.jl @@ -0,0 +1,11 @@ +module OneHotArraysCategoricalArraysExt + +println("loading?") + +using OneHotArrays, CategoricalArrays + +OneHotArrays.OneHotArray(cv::CategoricalValue) = OneHotVector(cv.ref, length(cv.pool.levels)) + +OneHotArrays.OneHotArray(ca::CategoricalArray) = OneHotArray(ca.refs, length(ca.pool)) + +end # module diff --git a/test/ext_categorical.jl b/test/ext_categorical.jl new file mode 100644 index 0000000..d227c19 --- /dev/null +++ b/test/ext_categorical.jl @@ -0,0 +1,20 @@ +using Test, OneHotArrays, CategoricalArrays + +@testset "CategoricalArrays -> OneHotArrays" begin + cval = CategoricalArrays.CategoricalValue('b', CategoricalArray('a':'z')) + + @test OneHotArray(cval) isa OneHotVector + @test OneHotArray(cval) == (('a':'z') .== 'b') + + @test_broken OneHotVector(cval) isa OneHotVector # surely if OneHotArray works, subtypes should too + @test_broken convert(OneHotArray, cval) isa OneHotVector + @test_broken onehot(cval) isa OneHotVector # possibly we should define this? Instead? + + cvec = categorical(string.([:a, :b, :b, :c, :d, :e])) + + @test OneHotArray(cvec) isa OneHotMatrix + @test size(OneHotArray(cvec)) == (5, 6) + @test onecold(OneHotArray(cvec)) == [1, 2, 2, 3, 4, 5] + + @test_broken onehotbatch(cvec) isa OneHotMatrix # possibly we should define this? Instead? +end diff --git a/test/runtests.jl b/test/runtests.jl index f3e50d1..ca83ee8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,6 +14,10 @@ end include("linalg.jl") end +@testset "Extensions" begin + include("ext_categorical.jl") +end + using Zygote import CUDA if CUDA.functional()