Skip to content

Commit d03cad0

Browse files
guangtao
authored and committed
Optimize wrapped-column write paths and fix RAT
1 parent affcd1b commit d03cad0

File tree

9 files changed

+192
-18
lines changed

9 files changed

+192
-18
lines changed

dev/release/rat_exclude_files.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717

1818
Manifest.toml
19+
*/Manifest.toml
1920
dev/release/apache-rat-*.jar
2021
dev/release/filtered_rat.txt
2122
dev/release/rat.xml

src/ArrowTypes/src/ArrowTypes.jl

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,12 @@ end
419419
# lazily call toarrow(x) on getindex for each x in data
420420
"""
    ToArrow{T,A} <: AbstractVector{T}

Lazy wrapper around `data::A` that presents its elements with element type `T`.
Elements are passed through `toarrow`/`_convert` on access only when
`needsconvert` is `true`.
"""
struct ToArrow{T,A} <: AbstractVector{T}
    data::A
    # `firstindex(data) - 1`; added to an index before indexing into `data`
    offset::Int
    # `false` when elements of `data` can be handed out untouched
    needsconvert::Bool
end

# Wrap `data`, precomputing the index shift and whether per-element conversion
# is required. Conversion is skipped only when the source element type already
# is `T` and `T` satisfies `concrete_or_concreteunion`.
function ToArrow{T,A}(data::A) where {T,A}
    passthrough = eltype(A) === T && concrete_or_concreteunion(T)
    shift = firstindex(data) - 1
    return ToArrow{T,A}(data, shift, !passthrough)
end
423429

424430
concrete_or_concreteunion(T) =
@@ -464,7 +470,29 @@ function _convert(::Type{T}, x) where {T}
464470
return convert(T, x)
465471
end
466472
end
467-
Base.getindex(x::ToArrow{T}, i::Int) where {T} =
468-
_convert(T, toarrow(getindex(x.data, i + firstindex(x.data) - 1)))
473+
474+
# Produce the element that `x` exposes for a raw `value` taken from `x.data`:
# the value itself when no conversion is needed, otherwise
# `_convert(T, toarrow(value))`.
@inline function _toarrowvalue(x::ToArrow{T}, value) where {T}
    if x.needsconvert
        return _convert(T, toarrow(value))
    end
    return value
end
478+
479+
# Index `x` at position `i`; `x.offset` shifts `i` into the index space of
# `x.data`, and the raw element is then passed through `_toarrowvalue`.
#
# The parent access is deliberately NOT wrapped in an unconditional
# `@inbounds`: doing so would disable bounds checking for every caller and
# turn an out-of-range index into an invalid memory read instead of a
# `BoundsError`. `@propagate_inbounds` already lets callers that have proven
# their indices valid elide the check by using `@inbounds` at the call site.
Base.@propagate_inbounds function Base.getindex(x::ToArrow{T}, i::Int) where {T}
    value = getindex(x.data, i + x.offset)
    return _toarrowvalue(x, value)
end
483+
484+
# Start iterating `x` by iterating the wrapped `x.data` directly (no index
# arithmetic), mapping the first element through `_toarrowvalue`.
function Base.iterate(x::ToArrow)
    next = iterate(x.data)
    if next === nothing
        return nothing
    end
    item, nextstate = next
    return _toarrowvalue(x, item), nextstate
end
490+
491+
# Continue iterating `x` from state `st`, which is the iteration state of the
# wrapped `x.data`; each element is mapped through `_toarrowvalue`.
function Base.iterate(x::ToArrow, st)
    next = iterate(x.data, st)
    if next === nothing
        return nothing
    end
    item, nextstate = next
    return _toarrowvalue(x, item), nextstate
end
469497

470498
end # module ArrowTypes

src/ArrowTypes/test/tests.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,26 +192,46 @@ end
192192
x = ArrowTypes.ToArrow([:hey, :ho])
193193
@test x isa ArrowTypes.ToArrow{String,Vector{Symbol}}
194194
@test eltype(x) == String
195+
@test getfield(x, :needsconvert)
196+
@test x[1] == "hey"
197+
@test collect(x) == ["hey", "ho"]
195198
@test x == ["hey", "ho"]
196199

197200
x = ArrowTypes.ToArrow(Any[1, 3.14])
198201
@test x isa ArrowTypes.ToArrow{Float64,Vector{Any}}
199202
@test eltype(x) == Float64
203+
@test collect(x) == [1.0, 3.14]
200204
@test x == [1.0, 3.14]
201205

202206
x = ArrowTypes.ToArrow(Any[1, 3.14, "hey"])
203207
@test x isa ArrowTypes.ToArrow{Union{Float64,String},Vector{Any}}
204208
@test eltype(x) == Union{Float64,String}
209+
@test collect(x) == Union{Float64,String}[1.0, 3.14, "hey"]
205210
@test x == [1.0, 3.14, "hey"]
206211

207212
x = ArrowTypes.ToArrow(OffsetArray([1, 2, 3], -3:-1))
208213
@test x isa ArrowTypes.ToArrow{Int,OffsetVector{Int,Vector{Int}}}
209214
@test eltype(x) == Int
215+
@test !getfield(x, :needsconvert)
216+
@test x[1] == 1
217+
@test x[3] == 3
218+
@test collect(x) == [1, 2, 3]
210219
@test x == [1, 2, 3]
211220

221+
x = ArrowTypes.ToArrow(OffsetArray(Union{Missing,Int}[1, missing], -3:-2))
222+
@test x isa ArrowTypes.ToArrow{Union{Missing,Int},OffsetVector{Union{Missing,Int},Vector{Union{Missing,Int}}}}
223+
@test !getfield(x, :needsconvert)
224+
@test x[1] == 1
225+
@test x[2] === missing
226+
@test isequal(collect(x), Union{Missing,Int}[1, missing])
227+
212228
x = ArrowTypes.ToArrow(OffsetArray(Any[1, 3.14], -3:-2))
213229
@test x isa ArrowTypes.ToArrow{Float64,OffsetVector{Any,Vector{Any}}}
214230
@test eltype(x) == Float64
231+
@test getfield(x, :needsconvert)
232+
@test x[1] == 1
233+
@test x[2] == 3.14
234+
@test collect(x) == [1.0, 3.14]
215235
@test x == [1, 3.14]
216236

217237
@testset "respect non-missing concrete type" begin

src/arraytypes/arraytypes.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -201,15 +201,10 @@ end
201201
Base.size(p::ValidityBitmap) = (p.ℓ,)
202202
nullcount(x::ValidityBitmap) = x.nc
203203

204-
function ValidityBitmap(x)
205-
T = eltype(x)
206-
if !(T >: Missing)
207-
return ValidityBitmap(UInt8[], 1, length(x), 0)
208-
end
204+
function _validitybitmap(x, len)
209205
len = length(x)
210206
blen = cld(len, 8)
211207
bytes = Vector{UInt8}(undef, blen)
212-
st = iterate(x)
213208
nc = 0
214209
b = 0xff
215210
j = k = 1
@@ -232,6 +227,23 @@ function ValidityBitmap(x)
232227
return ValidityBitmap(nc == 0 ? UInt8[] : bytes, 1, nc == 0 ? 0 : len, nc)
233228
end
234229

230+
# Build the validity bitmap for `x`. When the element type cannot hold
# `missing`, no scan is needed and an empty-bytes bitmap with a zero count is
# returned; otherwise the work is delegated to `_validitybitmap`.
function ValidityBitmap(x)
    n = length(x)
    Missing <: eltype(x) || return ValidityBitmap(UInt8[], 1, n, 0)
    return _validitybitmap(x, n)
end
237+
238+
# `ToArrow` specialization of `ValidityBitmap`. When the wrapper performs no
# per-element conversion, the underlying `data` field is scanned directly
# instead of going through the wrapper.
function ValidityBitmap(x::ArrowTypes.ToArrow)
    n = length(x)
    if !(Missing <: eltype(x))
        return ValidityBitmap(UInt8[], 1, n, 0)
    end
    if getfield(x, :needsconvert)
        return _validitybitmap(x, n)
    else
        return _validitybitmap(getfield(x, :data), n)
    end
end
246+
235247
@propagate_inbounds function Base.getindex(p::ValidityBitmap, i::Integer)
236248
# no boundscheck because parent array should do it
237249
# if a validity bitmap is empty, it either means:

src/arraytypes/bool.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,7 @@ end
5252

5353
arrowvector(::BoolKind, x::BoolVector, i, nl, fi, de, ded, meta; kw...) = x
5454

55-
function arrowvector(::BoolKind, x, i, nl, fi, de, ded, meta; kw...)
56-
validity = ValidityBitmap(x)
57-
len = length(x)
55+
function _packboolbytes(x, len)
5856
blen = cld(len, 8)
5957
bytes = Vector{UInt8}(undef, blen)
6058
b = 0xff
@@ -74,6 +72,21 @@ function arrowvector(::BoolKind, x, i, nl, fi, de, ded, meta; kw...)
7472
if j > 1
7573
bytes[k] = b
7674
end
75+
return bytes
76+
end
77+
78+
# Generic `BoolKind` conversion: compute the validity bitmap, pack the values
# with `_packboolbytes`, and wrap both in a `BoolVector`. The positional
# arguments `i, nl, fi, de, ded` and the keywords are accepted but not used here.
function arrowvector(::BoolKind, x, i, nl, fi, de, ded, meta; kw...)
    bitmap = ValidityBitmap(x)
    n = length(x)
    packed = _packboolbytes(x, n)
    return BoolVector{eltype(x)}(packed, 1, bitmap, n, meta)
end
84+
85+
# `BoolKind` conversion for `ToArrow` inputs. When the wrapper needs no
# per-element conversion, the values are packed straight from the wrapped
# `data` field; otherwise they are packed through the wrapper.
function arrowvector(::BoolKind, x::ArrowTypes.ToArrow, i, nl, fi, de, ded, meta; kw...)
    bitmap = ValidityBitmap(x)
    n = length(x)
    if getfield(x, :needsconvert)
        packed = _packboolbytes(x, n)
    else
        packed = _packboolbytes(getfield(x, :data), n)
    end
    return BoolVector{eltype(x)}(packed, 1, bitmap, n, meta)
end
7992

src/arraytypes/list.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,22 @@ end
219219
return x, (i, chunk, chunk_i, chunk_len, len)
220220
end
221221

222+
# Write all of `bytes` to `io` in one `unsafe_write` call and return that
# call's result. `bytes` is kept rooted with `GC.@preserve` while its raw
# pointer is in use.
@inline function _writeuint8chunk(io::IO, bytes)
    nbytes = length(bytes)
    written = GC.@preserve bytes Base.unsafe_write(io, pointer(bytes), nbytes)
    return written
end
227+
228+
# Write the byte payload of every non-missing element of `col.data` to `io`,
# chunk by chunk, and return the summed results of the chunk writes. When
# `stringtype` is true, each element is first turned into bytes with
# `_codeunits`; otherwise the element itself is written.
function writearray(
    io::IO, ::Type{UInt8}, col::ToList{UInt8,stringtype}
) where {stringtype}
    total = 0
    for item in col.data
        if item !== missing
            payload = stringtype ? _codeunits(item) : item
            total += _writeuint8chunk(io, payload)
        end
    end
    return total
end
237+
222238
arrowvector(::ListKind, x::List, i, nl, fi, de, ded, meta; kw...) = x
223239

224240
function arrowvector(::ListKind, x, i, nl, fi, de, ded, meta; largelists::Bool=false, kw...)

src/utils.jl

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,33 @@ end
3737
# efficient writing of arrays
3838
writearray(io, col) = writearray(io, maybemissing(eltype(col)), col)
3939

40+
# Element-by-element fallback for writing `col` to `io`.
#
# Every element is written into an in-memory buffer, with `missing` replaced
# by `ArrowTypes.default(T)`, and the buffer is then written to `io` in one
# call. Returns the result of that final write.
#
# The previous per-element accumulation into `n` was dead code (it was
# overwritten by the final write's result) and has been removed.
function _writearrayfallback(io::IO, ::Type{T}, col) where {T}
    # NOTE(review): `sizeof(col)` only sizes the initial backing storage; for
    # wrapper types it may not match the payload size — confirm it is a useful
    # hint.
    data = Vector{UInt8}(undef, sizeof(col))
    buf = IOBuffer(data; write=true)
    for x in col
        Base.write(buf, coalesce(x, ArrowTypes.default(T)))
    end
    return Base.write(io, take!(buf))
end
50+
51+
# Write `length(data)` elements of `sizeof(T)` bytes each to `io`, straight
# from the memory behind `pointer(data)`, and return the result of
# `unsafe_write`.
#
# `data` is rooted with `GC.@preserve` for the duration of the call so the
# raw pointer cannot outlive the object it points into (the same pattern
# `_writeuint8chunk` uses).
@inline function _writearraycontiguous(io::IO, ::Type{T}, data) where {T}
    nbytes = sizeof(T) * length(data)
    GC.@preserve data begin
        return Base.unsafe_write(io, pointer(data), nbytes)
    end
end
54+
55+
# Return the array wrapped by `col` when it can be written as one raw memory
# block of `T` values, or `nothing` when the caller must fall back to
# element-wise writing.
#
# The wrapped data qualifies only if all of the following hold:
#   * `col` performs no per-element conversion,
#   * `strides(data) == (1,)`,
#   * `T` is a bits type, and
#   * `data` is an `AbstractVector{T}` or `AbstractVector{Union{T,Missing}}`.
#
# NOTE(review): `strides` is called on arbitrary wrapped data; presumably every
# array reaching this point supports it — confirm for non-strided parents.
# NOTE(review): for `Union{T,Missing}` data the raw bytes are written without
# substituting a default for missing slots, unlike the fallback — confirm this
# is intended.
@inline function _contiguoustoarrowdata(::Type{T}, col::ArrowTypes.ToArrow) where {T}
    if getfield(col, :needsconvert)
        return nothing
    end
    raw = getfield(col, :data)
    if strides(raw) != (1,)
        return nothing
    end
    isbitstype(T) || return nothing
    matches = raw isa AbstractVector{T} || raw isa AbstractVector{Union{T,Missing}}
    return matches ? raw : nothing
end
66+
4067
function writearray(io::IO, ::Type{T}, col) where {T}
4168
if col isa Vector{T}
4269
n = Base.write(io, col)
@@ -51,17 +78,17 @@ function writearray(io::IO, ::Type{T}, col) where {T}
5178
n += writearray(io, T, A)
5279
end
5380
else
54-
n = 0
55-
data = Vector{UInt8}(undef, sizeof(col))
56-
buf = IOBuffer(data; write=true)
57-
for x in col
58-
n += Base.write(buf, coalesce(x, ArrowTypes.default(T)))
59-
end
60-
n = Base.write(io, take!(buf))
81+
n = _writearrayfallback(io, T, col)
6182
end
6283
return n
6384
end
6485

86+
# `ToArrow` specialization of `writearray`: write the wrapped data as one raw
# block when `_contiguoustoarrowdata` finds one, otherwise write element by
# element through `_writearrayfallback`.
function writearray(io::IO, ::Type{T}, col::ArrowTypes.ToArrow) where {T}
    raw = _contiguoustoarrowdata(T, col)
    if raw === nothing
        return _writearrayfallback(io, T, col)
    end
    return _writearraycontiguous(io, T, raw)
end
91+
6592
getbit(v::UInt8, n::Integer) = (v & (1 << (n - 1))) > 0x00
6693

6794
function setbit(v::UInt8, b::Bool, n::Integer)

test/Project.toml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
118
[deps]
219
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
320
ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd"

test/runtests.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,34 @@ end
295295
@test isequal(tt.a, ['a', missing])
296296
end
297297

298+
@testset "# offset bool write paths" begin
    # Round-trip Bool columns stored in arrays whose first index is not 1,
    # with both concretely typed and `Any`-typed element storage.
    dense = Bool[true, false, true]
    sparse = Union{Missing,Bool}[true, missing, false]
    axis = -1:1
    t = (
        a=OffsetArray(copy(dense), axis),
        b=OffsetArray(copy(sparse), axis),
        c=OffsetArray(Any[true, false, true], axis),
        d=OffsetArray(Any[true, missing, false], axis),
    )
    tt = Arrow.Table(Arrow.tobuffer(t))
    @test tt.a == dense
    @test isequal(tt.b, sparse)
    @test tt.c == dense
    @test isequal(tt.d, sparse)
end
311+
312+
@testset "# offset primitive write paths" begin
    # Round-trip Int64 columns stored in arrays whose first index is not 1,
    # with both concretely typed and `Any`-typed element storage.
    dense = Int64[1, 2, 3]
    sparse = Union{Missing,Int64}[1, missing, 3]
    axis = -1:1
    t = (
        a=OffsetArray(copy(dense), axis),
        b=OffsetArray(copy(sparse), axis),
        c=OffsetArray(Any[1, 2, 3], axis),
        d=OffsetArray(Any[1, missing, 3], axis),
    )
    tt = Arrow.Table(Arrow.tobuffer(t))
    @test tt.a == dense
    @test isequal(tt.b, sparse)
    @test tt.c == dense
    @test isequal(tt.d, sparse)
end
325+
298326
@testset "# automatic custom struct serialization/deserialization" begin
299327
t = (col1=[CustomStruct(1, 2.3, "hey"), CustomStruct(4, 5.6, "there")],)
300328

@@ -974,6 +1002,18 @@ end
9741002
@test isequal(t1.bm, t2.bm)
9751003
@test isequal(t1.c, t2.c)
9761004
@test isequal(t1.cm, t2.cm)
1005+
1006+
toffset = (
1007+
b=OffsetArray([b"01", b"", b"3"], -1:1),
1008+
bm=OffsetArray(Union{Missing,Base.CodeUnits{UInt8,String}}[b"01", b"3", missing], -1:1),
1009+
c=OffsetArray(["a", "b", "c"], -1:1),
1010+
cm=OffsetArray(Union{Missing,String}["a", "c", missing], -1:1),
1011+
)
1012+
ttoffset = Arrow.Table(Arrow.tobuffer(toffset))
1013+
@test collect(toffset.b) == ttoffset.b
1014+
@test isequal(collect(toffset.bm), ttoffset.bm)
1015+
@test collect(toffset.c) == ttoffset.c
1016+
@test isequal(collect(toffset.cm), ttoffset.cm)
9771017
end
9781018

9791019
@testset "# 435" begin

0 commit comments

Comments
 (0)