Skip to content

Commit 65fa026

Browse files
committed
GH-48132: [Ruby] Add support for writing dictionary array
1 parent 7fcc0af commit 65fa026

File tree

7 files changed

+137
-64
lines changed

7 files changed

+137
-64
lines changed

ruby/red-arrow-format/lib/arrow-format/array.rb

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,12 +508,21 @@ def to_a
508508
end
509509

510510
class DictionaryArray < Array
511+
attr_reader :indices_buffer
512+
attr_reader :dictionary
511513
def initialize(type, size, validity_buffer, indices_buffer, dictionary)
512514
super(type, size, validity_buffer)
513515
@indices_buffer = indices_buffer
514516
@dictionary = dictionary
515517
end
516518

519+
def each_buffer
520+
return to_enum(__method__) unless block_given?
521+
522+
yield(@validity_buffer)
523+
yield(@indices_buffer)
524+
end
525+
517526
def to_a
518527
values = []
519528
@dictionary.each do |dictionary_chunk|

ruby/red-arrow-format/lib/arrow-format/bitmap.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def initialize(buffer, n_values)
2424
end
2525

2626
def [](i)
27-
(@validity_buffer.get_value(:U8, i / 8) & (1 << (i % 8))) > 0
27+
(@buffer.get_value(:U8, i / 8) & (1 << (i % 8))) > 0
2828
end
2929

3030
def each

ruby/red-arrow-format/lib/arrow-format/field.rb

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,8 @@ def to_flatbuffers
3434
fb_field = FB::Field::Data.new
3535
fb_field.name = @name
3636
fb_field.nullable = @nullable
37-
if @type.is_a?(DictionaryType)
38-
fb_field.type = @type.value_type.to_flatbuffers
39-
dictionary_encoding = FB::DictionaryEncoding::Data.new
40-
dictionary_encoding.id = @dictionary_id
41-
int = FB::Int::Data.new
42-
int.bit_width = @type.index_type.bit_width
43-
int.signed = @type.index_type.signed?
44-
dictionary_encoding.index_type = int
45-
dictionary_encoding.ordered = @type.ordered?
46-
dictionary_encoding.dictionary_kind =
47-
FB::DictionaryKind::DENSE_ARRAY
48-
fb_field.dictionary = dictionary
37+
if @type.respond_to?(:build_fb_field)
38+
@type.build_fb_field(fb_field, self)
4939
else
5040
fb_field.type = @type.to_flatbuffers
5141
end

ruby/red-arrow-format/lib/arrow-format/file-writer.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def build_footer
4141
fb_footer = FB::Footer::Data.new
4242
fb_footer.version = FB::MetadataVersion::V5
4343
fb_footer.schema = @fb_schema
44-
# fb_footer.dictionaries = ... # TODO
44+
fb_footer.dictionaries = @fb_dictionary_blocks
4545
fb_footer.record_batches = @fb_record_batch_blocks
4646
# fb_footer.custom_metadata = ... # TODO
4747
FB::Footer.serialize(fb_footer)

ruby/red-arrow-format/lib/arrow-format/streaming-writer.rb

Lines changed: 57 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,38 +29,26 @@ class StreamingWriter
2929
def initialize(output)
3030
@output = output
3131
@offset = 0
32+
@fb_dictionary_blocks = []
3233
@fb_record_batch_blocks = []
34+
@written_dictionary_offsets = {}
3335
end
3436

3537
def start(schema)
3638
write_message(build_metadata(schema.to_flatbuffers))
37-
# TODO: Write dictionaries
3839
end
3940

4041
def write_record_batch(record_batch)
41-
body_length = 0
42-
record_batch.all_buffers_enumerator.each do |buffer|
43-
body_length += aligned_buffer_size(buffer) if buffer
42+
record_batch.schema.fields.each_with_index do |field, i|
43+
next if field.dictionary_id.nil?
44+
dictionary_array = record_batch.columns[i]
45+
write_dictionary(field.dictionary_id, dictionary_array)
4446
end
45-
metadata = build_metadata(record_batch.to_flatbuffers, body_length)
46-
fb_block = FB::Block::Data.new
47-
fb_block.offset = @offset
48-
fb_block.meta_data_length =
49-
CONTINUATION.bytesize +
50-
MessagePullReader::METADATA_LENGTH_SIZE +
51-
metadata.bytesize
52-
fb_block.body_length = body_length
53-
@fb_record_batch_blocks << fb_block
54-
write_message(metadata) do
55-
record_batch.all_buffers_enumerator.each do |buffer|
56-
write_buffer(buffer) if buffer
57-
end
58-
end
59-
end
6047

61-
# TODO
62-
# def write_dictionary_delta(id, dictionary)
63-
# end
48+
write_record_batch_based_message(record_batch,
49+
record_batch.to_flatbuffers,
50+
@fb_record_batch_blocks)
51+
end
6452

6553
def finish
6654
write_data(EOS)
@@ -100,6 +88,53 @@ def build_metadata(header, body_length=0)
10088
metadata
10189
end
10290

91+
def write_record_batch_based_message(record_batch, fb_header, fb_blocks)
92+
body_length = 0
93+
record_batch.all_buffers_enumerator.each do |buffer|
94+
body_length += aligned_buffer_size(buffer) if buffer
95+
end
96+
metadata = build_metadata(fb_header, body_length)
97+
fb_block = FB::Block::Data.new
98+
fb_block.offset = @offset
99+
fb_block.meta_data_length =
100+
CONTINUATION.bytesize +
101+
MessagePullReader::METADATA_LENGTH_SIZE +
102+
metadata.bytesize
103+
fb_block.body_length = body_length
104+
fb_blocks << fb_block
105+
write_message(metadata) do
106+
record_batch.all_buffers_enumerator.each do |buffer|
107+
write_buffer(buffer) if buffer
108+
end
109+
end
110+
end
111+
112+
def write_dictionary(id, dictionary_array)
113+
value_type = dictionary_array.type.value_type
114+
dictionary = dictionary_array.dictionary
115+
116+
offset = @written_dictionary_offsets[id]
117+
if offset.nil?
118+
is_delta = false
119+
else
120+
is_delta = true
121+
raise NotImplementedError,
122+
"Delta dictionary message isn't implemented yet"
123+
end
124+
125+
schema = Schema.new([Field.new("dummy", value_type, true, nil)])
126+
size = dictionary.size
127+
record_batch = RecordBatch.new(schema, size, [dictionary])
128+
fb_dictionary_batch = FB::DictionaryBatch::Data.new
129+
fb_dictionary_batch.id = id
130+
fb_dictionary_batch.data = record_batch.to_flatbuffers
131+
fb_dictionary_batch.delta = is_delta
132+
write_record_batch_based_message(record_batch,
133+
fb_dictionary_batch,
134+
@fb_dictionary_blocks)
135+
@written_dictionary_offsets[id] = dictionary_array.dictionary.size
136+
end
137+
103138
def write_message(metadata)
104139
write_data(CONTINUATION)
105140
metadata_size = metadata.bytesize

ruby/red-arrow-format/lib/arrow-format/type.rb

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,5 +873,19 @@ def build_array(size, validity_buffer, indices_buffer, dictionary)
873873
indices_buffer,
874874
dictionary)
875875
end
876+
877+
def build_fb_field(fb_field, field)
878+
fb_dictionary_encoding = FB::DictionaryEncoding::Data.new
879+
fb_dictionary_encoding.id = field.dictionary_id
880+
fb_int = FB::Int::Data.new
881+
fb_int.bit_width = @index_type.bit_width
882+
fb_int.signed = @index_type.signed?
883+
fb_dictionary_encoding.index_type = fb_int
884+
fb_dictionary_encoding.ordered = @ordered
885+
fb_dictionary_encoding.dictionary_kind =
886+
FB::DictionaryKind::DENSE_ARRAY
887+
fb_field.type = @value_type.to_flatbuffers
888+
fb_field.dictionary = fb_dictionary_encoding
889+
end
876890
end
877891
end

ruby/red-arrow-format/test/test-writer.rb

Lines changed: 53 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -106,16 +106,30 @@ def convert_type(red_arrow_type)
106106
convert_field(field)
107107
end
108108
ArrowFormat::SparseUnionType.new(fields, red_arrow_type.type_codes)
109+
when Arrow::DictionaryDataType
110+
index_type = convert_type(red_arrow_type.index_data_type)
111+
type = convert_type(red_arrow_type.value_data_type)
112+
ArrowFormat::DictionaryType.new(index_type,
113+
type,
114+
red_arrow_type.ordered?)
109115
else
110116
raise "Unsupported type: #{red_arrow_type.inspect}"
111117
end
112118
end
113119

114120
def convert_field(red_arrow_field)
121+
type = convert_type(red_arrow_field.data_type)
122+
if type.is_a?(ArrowFormat::DictionaryType)
123+
@dictionary_id ||= 0
124+
dictionary_id = @dictionary_id
125+
@dictionary_id += 1
126+
else
127+
dictionary_id = nil
128+
end
115129
ArrowFormat::Field.new(red_arrow_field.name,
116-
convert_type(red_arrow_field.data_type),
130+
type,
117131
red_arrow_field.nullable?,
118-
nil)
132+
dictionary_id)
119133
end
120134

121135
def convert_buffer(buffer)
@@ -171,11 +185,33 @@ def convert_array(red_arrow_array)
171185
type.build_array(red_arrow_array.size,
172186
types_buffer,
173187
children)
188+
when ArrowFormat::DictionaryType
189+
validity_buffer = convert_buffer(red_arrow_array.null_bitmap)
190+
indices_buffer = convert_buffer(red_arrow_array.indices.data_buffer)
191+
dictionary = convert_array(red_arrow_array.dictionary)
192+
type.build_array(red_arrow_array.size,
193+
validity_buffer,
194+
indices_buffer,
195+
dictionary)
174196
else
175197
raise "Unsupported array #{red_arrow_array.inspect}"
176198
end
177199
end
178200

201+
def write(writer)
202+
red_arrow_array = build_array
203+
array = convert_array(red_arrow_array)
204+
red_arrow_field = Arrow::Field.new("value",
205+
red_arrow_array.value_data_type,
206+
true)
207+
fields = [convert_field(red_arrow_field)]
208+
schema = ArrowFormat::Schema.new(fields)
209+
record_batch = ArrowFormat::RecordBatch.new(schema, array.size, [array])
210+
writer.start(schema)
211+
writer.write_record_batch(record_batch)
212+
writer.finish
213+
end
214+
179215
class << self
180216
def included(base)
181217
base.class_eval do
@@ -939,6 +975,19 @@ def test_write
939975
@values)
940976
end
941977
end
978+
979+
sub_test_case("Dictionary") do
980+
def build_array
981+
values = ["a", "b", "c", nil, "a"]
982+
string_array = Arrow::StringArray.new(values)
983+
string_array.dictionary_encode
984+
end
985+
986+
def test_write
987+
assert_equal(["a", "b", "c", nil, "a"],
988+
@values)
989+
end
990+
end
942991
end
943992
end
944993
end
@@ -952,19 +1001,7 @@ def setup
9521001
path = File.join(tmp_dir, "data.arrow")
9531002
File.open(path, "wb") do |output|
9541003
writer = ArrowFormat::FileWriter.new(output)
955-
red_arrow_array = build_array
956-
array = convert_array(red_arrow_array)
957-
fields = [
958-
ArrowFormat::Field.new("value",
959-
array.type,
960-
true,
961-
nil),
962-
]
963-
schema = ArrowFormat::Schema.new(fields)
964-
record_batch = ArrowFormat::RecordBatch.new(schema, array.size, [array])
965-
writer.start(schema)
966-
writer.write_record_batch(record_batch)
967-
writer.finish
1004+
write(writer)
9681005
end
9691006
data = File.open(path, "rb", &:read).freeze
9701007
table = Arrow::Table.load(Arrow::Buffer.new(data), format: :arrow)
@@ -982,19 +1019,7 @@ def setup
9821019
path = File.join(tmp_dir, "data.arrows")
9831020
File.open(path, "wb") do |output|
9841021
writer = ArrowFormat::StreamingWriter.new(output)
985-
red_arrow_array = build_array
986-
array = convert_array(red_arrow_array)
987-
fields = [
988-
ArrowFormat::Field.new("value",
989-
array.type,
990-
true,
991-
nil),
992-
]
993-
schema = ArrowFormat::Schema.new(fields)
994-
record_batch = ArrowFormat::RecordBatch.new(schema, array.size, [array])
995-
writer.start(schema)
996-
writer.write_record_batch(record_batch)
997-
writer.finish
1022+
write(writer)
9981023
end
9991024
data = File.open(path, "rb", &:read).freeze
10001025
table = Arrow::Table.load(Arrow::Buffer.new(data), format: :arrows)

0 commit comments

Comments
 (0)