Skip to content
97 changes: 51 additions & 46 deletions src/arraytypes/dictencoding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,55 +142,58 @@ function arrowvector(
kw...,
)
id = x.encoding.id
# XXX This is a race condition if two workers hit this block at the same time, then they'll create
# distinct locks
if !haskey(de, id)
de[id] = Lockable(x.encoding)
else
encodinglockable = de[id]
Base.@lock encodinglockable begin
encoding = encodinglockable.value
# in this case, we just need to check if any values in our local pool need to be delta dicationary serialized
deltas = setdiff(x.encoding, encoding)
if !isempty(deltas)
ET = indextype(encoding)
if length(deltas) + length(encoding) > typemax(ET)
error(
"fatal error serializing dict encoded column with ref index type of $ET; subsequent record batch unique values resulted in $(length(deltas) + length(encoding)) unique values, which exceeds possible index values in $ET",
)
end
data = arrowvector(
deltas,
i,
nl,
fi,
de,
ded,
nothing;
dictencode=dictencodenested,
dictencodenested=dictencodenested,
dictencoding=true,
kw...,
return x
end

encodinglockable = de[id]
Base.@lock encodinglockable begin
encoding = encodinglockable.value
# in this case, we just need to check if any values in our local pool need to be delta dictionary serialized
deltas = setdiff(x.encoding, encoding)
if !isempty(deltas)
ET = indextype(encoding)
if length(deltas) + length(encoding) > typemax(ET)
error(
"fatal error serializing dict encoded column with ref index type of $ET; subsequent record batch unique values resulted in $(length(deltas) + length(encoding)) unique values, which exceeds possible index values in $ET",
)
push!(
ded,
DictEncoding{eltype(data),ET,typeof(data)}(
id,
data,
false,
getmetadata(data),
),
end
data = arrowvector(
deltas,
i,
nl,
fi,
de,
ded,
nothing;
dictencode=dictencodenested,
dictencodenested=dictencodenested,
dictencoding=true,
kw...,
)
push!(
ded,
DictEncoding{eltype(data),ET,typeof(data)}(
id,
data,
false,
getmetadata(data),
),
)
if typeof(encoding.data) <: ChainedVector
append!(encoding.data, data)
else
data2 = ChainedVector([encoding.data, data])
encoding = DictEncoding{eltype(data2),ET,typeof(data2)}(
id,
data2,
false,
getmetadata(encoding),
)
if typeof(encoding.data) <: ChainedVector
append!(encoding.data, data)
else
data2 = ChainedVector([encoding.data, data])
encoding = DictEncoding{eltype(data2),ET,typeof(data2)}(
id,
data2,
false,
getmetadata(encoding),
)
de[id] = Lockable(encoding)
end
de[id] = Lockable(encoding, encodinglockable.lock)
end
end
end
Expand All @@ -215,6 +218,8 @@ function arrowvector(
x = x.data
len = length(x)
validity = ValidityBitmap(x)
# XXX This is a race condition if two workers hit this block at the same time, then they'll create
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@quinnj I think there is a race condition baked into the current architecture that can't be addressed without a very large refactoring. The current architecture creates the locks on a worker thread if they don't already exist, which means that threads are competing for the creation of the initial lock. The locks should be created before any tasks are spawned.

# distinct locks
if !haskey(de, id)
# dict encoding doesn't exist yet, so create for 1st time
if DataAPI.refarray(x) === x || DataAPI.refpool(x) === nothing
Expand Down Expand Up @@ -326,7 +331,7 @@ function arrowvector(
false,
getmetadata(encoding),
)
de[id] = Lockable(encoding)
de[id] = Lockable(encoding, encodinglockable.lock)
end
end
end
Expand Down
46 changes: 24 additions & 22 deletions src/write.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,27 +295,29 @@ function write(writer::Writer, source)
recbatchmsg = makerecordbatchmsg(writer.schema[], cols, writer.alignment)
put!(writer.msgs, recbatchmsg)
else
if writer.threaded
@wkspawn process_partition(
tblcols,
writer.dictencodings,
writer.largelists,
writer.compress,
writer.denseunions,
writer.dictencode,
writer.dictencodenested,
writer.maxdepth,
writer.sync,
writer.msgs,
writer.alignment,
$(writer.partition_count),
writer.schema,
writer.errorref,
writer.anyerror,
writer.meta,
writer.colmeta,
)
else
# XXX There is a race condition in the processing of dict encodings
# so we disable multithreaded writing until that can be addressed. See #582
# if writer.threaded
# @wkspawn process_partition(
# tblcols,
# writer.dictencodings,
# writer.largelists,
# writer.compress,
# writer.denseunions,
# writer.dictencode,
# writer.dictencodenested,
# writer.maxdepth,
# writer.sync,
# writer.msgs,
# writer.alignment,
# $(writer.partition_count),
# writer.schema,
# writer.errorref,
# writer.anyerror,
# writer.meta,
# writer.colmeta,
# )
# else
@async process_partition(
tblcols,
writer.dictencodings,
Expand All @@ -335,7 +337,7 @@ function write(writer::Writer, source)
writer.meta,
writer.colmeta,
)
end
# end
end
writer.partition_count += 1
end
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
TimeZones = "f269a46b-ccf7-5d73-abea-4c690281aa53"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[compat]
Expand All @@ -44,4 +45,5 @@ PooledArrays = "1"
StructTypes = "1"
SentinelArrays = "1"
Tables = "1"
TestSetExtensions = "3"
TimeZones = "1"
16 changes: 11 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,16 @@ using DataAPI
using FilePathsBase
using DataFrames
import Random: randstring
using TestSetExtensions: ExtendedTestSet
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

given how long the Arrow tests take, it's useful to have some indication of progress so that we can tell if tests have hung. ExtendedTestSet shows a . for each completed test.

(We also get colored diffs of arrays when tests fail, which is nice.)


# this formulation tests the loaded ArrowTypes, even if it's not the dev version
# within the mono-repo
include(joinpath(dirname(pathof(ArrowTypes)), "../test/tests.jl"))
include(joinpath(dirname(pathof(Arrow)), "../test/testtables.jl"))
include(joinpath(dirname(pathof(Arrow)), "../test/testappend.jl"))
include(joinpath(dirname(pathof(Arrow)), "../test/integrationtest.jl"))
include(joinpath(dirname(pathof(Arrow)), "../test/dates.jl"))

include(joinpath(@__DIR__, "testtables.jl"))
include(joinpath(@__DIR__, "testappend.jl"))
include(joinpath(@__DIR__, "integrationtest.jl"))
include(joinpath(@__DIR__, "dates.jl"))

struct CustomStruct
x::Int
Expand All @@ -45,7 +49,7 @@ struct CustomStruct2{sym}
x::Int
end

@testset "Arrow" begin
@testset ExtendedTestSet "Arrow" begin
@testset "table roundtrips" begin
for case in testtables
testtable(case...)
Expand Down Expand Up @@ -381,6 +385,8 @@ end
end

@testset "# 126" begin
# XXX This test also captures a race condition in multithreaded
# writes of dictionary encoded arrays
t = Tables.partitioner((
(a=Arrow.toarrowvector(PooledArray([1, 2, 3])),),
(a=Arrow.toarrowvector(PooledArray([1, 2, 3, 4])),),
Expand Down
Loading