Skip to content

Commit a7ccfd5

Browse files
authored
Support reading from and writing to Arrow files (#415)
This requires overriding `Arrow.DictEncoding` so that an `Arrow.DictEncoded` with a `CategoricalArray` dictionary with one entry per level is created. This is the only way to ensure that indexing the Arrow column gives `CategoricalValue` objects. In practice such columns will most often be used after conversion to `CategoricalArray` via `copy`, `DataFrame`, etc. Apparently, pandas do not allow reading the resulting file if the array allows for missing values as it does not accept `missing` in the dictionary. Instead it would need missing entries to be coded via null indices, which is less efficient. Require Julia 1.6 as tests fail on older Julia versions.
1 parent 8bfc647 commit a7ccfd5

File tree

5 files changed

+131
-3
lines changed

5 files changed

+131
-3
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
fail-fast: false
1313
matrix:
1414
version:
15-
- '1.0'
15+
- '1.6'
1616
- '1' # automatically expands to the latest stable 1.x release of Julia
1717
- 'nightly'
1818
os:

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,21 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1313
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
1414

1515
[weakdeps]
16+
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
1617
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
1718
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1819
SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
1920
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
2021

2122
[extensions]
23+
CategoricalArraysArrowExt = "Arrow"
2224
CategoricalArraysJSONExt = "JSON"
2325
CategoricalArraysRecipesBaseExt = "RecipesBase"
2426
CategoricalArraysSentinelArraysExt = "SentinelArrays"
2527
CategoricalArraysStructTypesExt = "StructTypes"
2628

2729
[compat]
30+
Arrow = "2"
2831
Compat = "3.37, 4"
2932
DataAPI = "1.6"
3033
JSON = "0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21"
@@ -35,9 +38,10 @@ Requires = "1"
3538
SentinelArrays = "1"
3639
Statistics = "1"
3740
StructTypes = "1"
38-
julia = "1"
41+
julia = "1.6"
3942

4043
[extras]
44+
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
4145
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
4246
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
4347
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
@@ -49,4 +53,4 @@ StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
4953
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5054

5155
[targets]
52-
test = ["Dates", "JSON", "JSON3", "Plots", "PooledArrays", "RecipesBase", "SentinelArrays", "StructTypes", "Test"]
56+
test = ["Arrow", "Dates", "JSON", "JSON3", "Plots", "PooledArrays", "RecipesBase", "SentinelArrays", "StructTypes", "Test"]

ext/CategoricalArraysArrowExt.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
module CategoricalArraysArrowExt
2+
3+
using CategoricalArrays
4+
import Arrow
5+
import Arrow: ArrowTypes
6+
7+
const CATARRAY_ARROWNAME = Symbol("JuliaLang.CategoricalArrays.CategoricalArray")
8+
ArrowTypes.arrowname(::Type{<:CategoricalValue}) = CATARRAY_ARROWNAME
9+
ArrowTypes.arrowmetadata(::Type{CategoricalValue{T, R}}) where {T, R} = string(R)
10+
11+
ArrowTypes.arrowname(::Type{Union{<:CategoricalValue, Missing}}) = CATARRAY_ARROWNAME
12+
ArrowTypes.arrowmetadata(::Type{Union{CategoricalValue{T, R}, Missing}}) where {T, R} =
13+
string(R)
14+
15+
const REFTYPES = Dict(string(T) => T for T in (Int128, Int16, Int32, Int64, Int8, UInt128,
16+
UInt16, UInt32, UInt64, UInt8))
17+
function ArrowTypes.JuliaType(::Val{CATARRAY_ARROWNAME},
18+
::Type{S}, meta::String) where S
19+
R = REFTYPES[meta]
20+
return CategoricalValue{S, R}
21+
end
22+
23+
for (MV, MT) in ((:V, :T), (:(Union{V,Missing}), :(Union{T,Missing})))
24+
@eval begin
25+
function Arrow.DictEncoding{$MV,S,A}(id, data::Arrow.List{U, O, B},
26+
isOrdered, metadata) where
27+
{T, R, V<:CategoricalValue{T,R}, S, O, A, B, U}
28+
newdata = Arrow.List{$MT,O,B}(data.arrow, data.validity, data.offsets,
29+
data.data, data.ℓ, data.metadata)
30+
levels = Missing <: $MT ? collect(skipmissing(newdata)) : newdata
31+
catdata = CategoricalVector{$MT,R}(newdata, levels=levels)
32+
return Arrow.DictEncoding{$MV,S,typeof(catdata)}(id, catdata,
33+
isOrdered, metadata)
34+
end
35+
36+
function Arrow.DictEncoding{$MV,S,A}(id, data::Arrow.Primitive{U, B},
37+
isOrdered, metadata) where
38+
{T, R, V<:CategoricalValue{T,R}, S, A, B, U}
39+
newdata = Arrow.Primitive{$MT,B}(data.arrow, data.validity, data.data,
40+
data.ℓ, data.metadata)
41+
levels = Missing <: $MT ? collect(skipmissing(newdata)) : newdata
42+
catdata = CategoricalVector{$MT,R}(newdata, levels=levels)
43+
return Arrow.DictEncoding{$MV,S,typeof(catdata)}(id, catdata,
44+
isOrdered, metadata)
45+
end
46+
end
47+
end
48+
49+
function Base.copy(x::Arrow.DictEncoded{V}) where {T, R, V<:CategoricalValue{T, R}}
50+
pool = CategoricalPool{T,R}(x.encoding.data)
51+
inds = x.indices
52+
refs = similar(inds, R)
53+
refs .= inds .+ one(R)
54+
return CategoricalVector{T}(refs, pool)
55+
end
56+
57+
function Base.copy(x::Arrow.DictEncoded{Union{Missing,V}}) where
58+
{T, R, V<:CategoricalValue{T, R}}
59+
ismissing(x.encoding.data[1]) ||
60+
throw(ErrorException("`missing` must be the first value in a " *
61+
"`CategoricalArray` pool"))
62+
levels = collect(skipmissing(x.encoding.data))
63+
pool = CategoricalPool{T,R}(levels)
64+
inds = x.indices
65+
refs = similar(inds, R)
66+
refs .= inds
67+
return CategoricalVector{Union{T,Missing}}(refs, pool)
68+
end
69+
70+
end

src/CategoricalArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ module CategoricalArrays
4141

4242
@static if !isdefined(Base, :get_extension)
4343
function __init__()
44+
@require Arrow="69666777-d1a9-59fb-9406-91d4454c9d45" include("../ext/CategoricalArraysArrowExt.jl")
4445
@require JSON="682c06a0-de6a-54ab-a142-c8b1cf79cde6" include("../ext/CategoricalArraysJSONExt.jl")
4546
@require RecipesBase="3cdcf5f2-1ef4-517c-9805-6587b60abb01" include("../ext/CategoricalArraysRecipesBaseExt.jl")
4647
@require SentinelArrays="91c51154-3ec4-41a3-a24f-3f23e20d615c" include("../ext/CategoricalArraysSentinelArraysExt.jl")

test/13_arraycommon.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ using StructTypes
1010
using RecipesBase
1111
using Plots
1212
using SentinelArrays
13+
using Arrow
14+
using Missings
1315

1416
const = isequal
1517
const = !isequal
@@ -2071,6 +2073,57 @@ StructTypes.StructType(::Type{<:MyCustomType}) = StructTypes.Struct()
20712073
@test levels(readx.var) == levels(x.var)
20722074
end
20732075

2076+
if Int == Int64
2077+
@testset "writing and reading Arrow files" for f in (identity, passmissing(string))
2078+
xref = f.([3, 1, 4, 1, 4])
2079+
x = categorical(f.([3, 1, 4, 1, 4]))
2080+
tbl = mktemp() do path, io
2081+
Arrow.write(path, (x=x,))
2082+
Arrow.Table(path)
2083+
end
2084+
@test tbl.x == x
2085+
@test tbl.x isa Arrow.DictEncoded{CategoricalValue{eltype(xref), UInt32}, Int8,
2086+
<: CategoricalVector{eltype(xref), UInt32}}
2087+
@test copy(tbl.x) == x
2088+
@test copy(x) isa CategoricalArray{eltype(xref),1,UInt32}
2089+
2090+
x = categorical(f.([3, 1, 4, 1, 4]), compress=true)
2091+
tbl = mktemp() do path, io
2092+
Arrow.write(path, (x=x,))
2093+
Arrow.Table(path)
2094+
end
2095+
@test tbl.x == x
2096+
@test tbl.x isa Arrow.DictEncoded{CategoricalValue{eltype(xref), UInt8}, Int8,
2097+
<: CategoricalVector{eltype(xref), UInt8}}
2098+
@test copy(tbl.x) == x
2099+
@test copy(x) isa CategoricalArray{eltype(xref),1,UInt8}
2100+
2101+
x = categorical(recode(xref, 1 => missing))
2102+
tbl = mktemp() do path, io
2103+
Arrow.write(path, (x=x,))
2104+
Arrow.Table(path)
2105+
end
2106+
@test tbl.x x
2107+
@test tbl.x isa Arrow.DictEncoded{Union{CategoricalValue{eltype(xref), UInt32}, Missing},
2108+
Int8,
2109+
<: CategoricalVector{Union{eltype(xref), Missing},
2110+
UInt32}}
2111+
@test copy(tbl.x) x
2112+
@test copy(x) isa CategoricalArray{Union{eltype(xref), Missing},1,UInt32}
2113+
2114+
recode!(x, missing => f(1))
2115+
tbl = mktemp() do path, io
2116+
Arrow.write(path, (x=x,))
2117+
Arrow.Table(path)
2118+
end
2119+
@test tbl.x == x
2120+
@test tbl.x isa Arrow.DictEncoded{Union{CategoricalValue{eltype(xref), UInt32}, Missing}, Int8,
2121+
<: CategoricalVector{Union{eltype(xref), Missing}, UInt32}}
2122+
@test copy(tbl.x) == x
2123+
@test copy(x) isa CategoricalArray{Union{eltype(xref), Missing},1,UInt32}
2124+
end
2125+
end
2126+
20742127
@testset "refarray, refvalue, refpool, and invrefpool" begin
20752128
for y in (categorical(["b", "a", "c", "b"]),
20762129
view(categorical(["a", "a", "c", "b"]), 1:3),

0 commit comments

Comments
 (0)