|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | + * contributor license agreements. See the NOTICE file distributed with |
| 4 | + * this work for additional information regarding copyright ownership. |
| 5 | + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | + * (the "License"); you may not use this file except in compliance with |
| 7 | + * the License. You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + */ |
| 17 | + |
| 18 | +package org.apache.arrow.adapter.avro; |
| 19 | + |
| 20 | +import org.apache.arrow.adapter.avro.consumers.CompositeAvroConsumer; |
| 21 | +import org.apache.arrow.memory.BufferAllocator; |
| 22 | +import org.apache.arrow.vector.FieldVector; |
| 23 | +import org.apache.arrow.vector.VectorSchemaRoot; |
| 24 | +import org.apache.arrow.vector.dictionary.Dictionary; |
| 25 | +import org.apache.arrow.vector.dictionary.DictionaryProvider; |
| 26 | +import org.apache.arrow.vector.types.pojo.Schema; |
| 27 | +import org.apache.avro.file.DataFileConstants; |
| 28 | +import org.apache.avro.io.BinaryData; |
| 29 | +import org.apache.avro.io.BinaryDecoder; |
| 30 | +import org.apache.avro.io.DecoderFactory; |
| 31 | + |
| 32 | +import java.io.*; |
| 33 | +import java.nio.ByteBuffer; |
| 34 | +import java.nio.charset.StandardCharsets; |
| 35 | +import java.util.ArrayList; |
| 36 | +import java.util.List; |
| 37 | +import java.util.Set; |
| 38 | + |
| 39 | + |
| 40 | +class AvroFileReader implements DictionaryProvider { |
| 41 | + |
| 42 | + // Writer owns a channel / decoder and will close them |
| 43 | + // Schema / VSR / dictionaries are created when header is read |
| 44 | + // VSR / dictionaries are cleaned up on close |
| 45 | + // Dictionaries accessible through DictionaryProvider iface |
| 46 | + |
| 47 | + // Use magic from Avro's own constants |
| 48 | + private static final byte[] AVRO_MAGIC = DataFileConstants.MAGIC; |
| 49 | + private static final int SYNC_MARKER_SIZE = 16; |
| 50 | + |
| 51 | + private final InputStream stream; |
| 52 | + private final BinaryDecoder decoder; |
| 53 | + private final BufferAllocator allocator; |
| 54 | + private final boolean blocking; |
| 55 | + |
| 56 | + private org.apache.avro.Schema avroSchema; |
| 57 | + private String avroCodec; |
| 58 | + private final byte[] syncMarker; |
| 59 | + |
| 60 | + private CompositeAvroConsumer recordConsumer; |
| 61 | + private VectorSchemaRoot arrowBatch; |
| 62 | + private Schema arrowSchema; |
| 63 | + private DictionaryProvider.MapDictionaryProvider dictionaries; |
| 64 | + |
| 65 | + private long nextBatchPosition; |
| 66 | + private ByteBuffer batchBuffer; |
| 67 | + private BinaryDecoder batchDecoder; |
| 68 | + private final byte[] batchSyncMarker; |
| 69 | + |
| 70 | + // Create a new AvroFileReader for the input stream |
| 71 | + // In order to support non-blocking mode, the stream must support mark / reset |
| 72 | + public AvroFileReader( |
| 73 | + InputStream stream, |
| 74 | + BufferAllocator allocator, |
| 75 | + boolean blocking) { |
| 76 | + |
| 77 | + this.stream =stream; |
| 78 | + this.allocator = allocator; |
| 79 | + this.blocking = blocking; |
| 80 | + |
| 81 | + if (blocking) { |
| 82 | + this.decoder = DecoderFactory.get().binaryDecoder(stream, null); |
| 83 | + } else { |
| 84 | + if (!stream.markSupported()) { |
| 85 | + throw new IllegalArgumentException("Input stream must support mark/reset for non-blocking mode"); |
| 86 | + } |
| 87 | + this.decoder = DecoderFactory.get().directBinaryDecoder(stream, null); |
| 88 | + } |
| 89 | + |
| 90 | + this.syncMarker = new byte[SYNC_MARKER_SIZE]; |
| 91 | + this.batchSyncMarker = new byte[SYNC_MARKER_SIZE]; |
| 92 | + } |
| 93 | + |
| 94 | + // Read the Avro header and set up schema / VSR / dictionaries |
| 95 | + void readHeader() throws IOException { |
| 96 | + |
| 97 | + if (avroSchema != null) { |
| 98 | + throw new IllegalStateException("Avro header has already been read"); |
| 99 | + } |
| 100 | + |
| 101 | + // Keep track of the header size |
| 102 | + long headerSize = 0; |
| 103 | + |
| 104 | + // Read Avro magic |
| 105 | + byte[] magic = new byte[AVRO_MAGIC.length]; |
| 106 | + decoder.readFixed(magic); |
| 107 | + headerSize += magic.length; |
| 108 | + |
| 109 | + // Validate Avro magic |
| 110 | + int validateMagic = BinaryData.compareBytes( |
| 111 | + AVRO_MAGIC, 0, AVRO_MAGIC.length, |
| 112 | + magic, 0, AVRO_MAGIC.length); |
| 113 | + |
| 114 | + if (validateMagic != 0) { |
| 115 | + throw new RuntimeException("Invalid AVRO data file: The file is not an Avro file"); |
| 116 | + } |
| 117 | + |
| 118 | + // Read the metadata map |
| 119 | + for (long count = decoder.readMapStart(); count != 0; count = decoder.mapNext()) { |
| 120 | + |
| 121 | + headerSize += zigzagSize(count); |
| 122 | + |
| 123 | + for (long i = 0; i < count; i++) { |
| 124 | + |
| 125 | + ByteBuffer keyBuffer = decoder.readBytes(null); |
| 126 | + ByteBuffer valueBuffer = decoder.readBytes(null); |
| 127 | + |
| 128 | + headerSize += zigzagSize(keyBuffer.remaining()) + keyBuffer.remaining(); |
| 129 | + headerSize += zigzagSize(valueBuffer.remaining()) + valueBuffer.remaining(); |
| 130 | + |
| 131 | + String key = new String(keyBuffer.array(), StandardCharsets.UTF_8); |
| 132 | + |
| 133 | + // Handle header entries for schema and codec |
| 134 | + if ("avro.schema".equals(key)) { |
| 135 | + avroSchema = processSchema(valueBuffer); |
| 136 | + } else if ("avro.codec".equals(key)) { |
| 137 | + avroCodec = processCodec(valueBuffer); |
| 138 | + } |
| 139 | + } |
| 140 | + } |
| 141 | + |
| 142 | + // End of map marker |
| 143 | + headerSize += 1; |
| 144 | + |
| 145 | + // Sync marker denotes end of the header |
| 146 | + decoder.readFixed(syncMarker); |
| 147 | + headerSize += syncMarker.length; |
| 148 | + |
| 149 | + // Schema must always be present |
| 150 | + if (avroSchema == null) { |
| 151 | + throw new RuntimeException("Invalid AVRO data file: Schema missing in file header"); |
| 152 | + } |
| 153 | + |
| 154 | + // Prepare read config |
| 155 | + this.dictionaries = new DictionaryProvider.MapDictionaryProvider(); |
| 156 | + AvroToArrowConfig config = new AvroToArrowConfig(allocator, 0, dictionaries, Set.of(), false); |
| 157 | + |
| 158 | + // Calling this method will also populate the dictionary map |
| 159 | + this.recordConsumer = AvroToArrowUtils.createCompositeConsumer(avroSchema, config); |
| 160 | + |
| 161 | + // Initialize data vectors |
| 162 | + List<FieldVector> vectors = new ArrayList<>(arrowSchema.getFields().size()); |
| 163 | + for (int i = 0; i < arrowSchema.getFields().size(); i++) { |
| 164 | + FieldVector vector = recordConsumer.getConsumers().get(i).getVector(); |
| 165 | + vectors.add(vector); |
| 166 | + } |
| 167 | + |
| 168 | + // Initialize batch and schema |
| 169 | + this.arrowBatch = new VectorSchemaRoot(vectors); |
| 170 | + this.arrowSchema = arrowBatch.getSchema(); |
| 171 | + |
| 172 | + // First batch starts after the header |
| 173 | + this.nextBatchPosition = headerSize; |
| 174 | + } |
| 175 | + |
| 176 | + private org.apache.avro.Schema processSchema(ByteBuffer buffer) throws IOException { |
| 177 | + |
| 178 | + org.apache.avro.Schema.Parser parser = new org.apache.avro.Schema.Parser(); |
| 179 | + |
| 180 | + try (InputStream schemaStream = new ByteArrayInputStream(buffer.array())) { |
| 181 | + return parser.parse(schemaStream); |
| 182 | + } |
| 183 | + } |
| 184 | + |
| 185 | + private String processCodec(ByteBuffer buffer) { |
| 186 | + |
| 187 | + if (buffer != null && buffer.remaining() > 0) { |
| 188 | + return new String(buffer.array(), StandardCharsets.UTF_8); |
| 189 | + } |
| 190 | + else { |
| 191 | + return DataFileConstants.NULL_CODEC; |
| 192 | + } |
| 193 | + } |
| 194 | + |
| 195 | + // Schema and VSR available after readHeader() |
| 196 | + Schema getSchema() { |
| 197 | + if (avroSchema == null) { |
| 198 | + throw new IllegalStateException("Avro header has not been read yet"); |
| 199 | + } |
| 200 | + return arrowSchema; |
| 201 | + } |
| 202 | + |
| 203 | + VectorSchemaRoot getVectorSchemaRoot() { |
| 204 | + if (avroSchema == null) { |
| 205 | + throw new IllegalStateException("Avro header has not been read yet"); |
| 206 | + } |
| 207 | + return arrowBatch; |
| 208 | + } |
| 209 | + |
| 210 | + @Override |
| 211 | + public Set<Long> getDictionaryIds() { |
| 212 | + if (avroSchema == null) { |
| 213 | + throw new IllegalStateException("Avro header has not been read yet"); |
| 214 | + } |
| 215 | + return dictionaries.getDictionaryIds(); |
| 216 | + } |
| 217 | + |
| 218 | + @Override |
| 219 | + public Dictionary lookup(long id) { |
| 220 | + if (avroSchema == null) { |
| 221 | + throw new IllegalStateException("Avro header has not been read yet"); |
| 222 | + } |
| 223 | + return dictionaries.lookup(id); |
| 224 | + } |
| 225 | + |
| 226 | + // Read the next Avro block and load it into the VSR |
| 227 | + // Return true if successful, false if EOS |
| 228 | + // Also false in non-blocking mode if need more data |
| 229 | + boolean readBatch() throws IOException { |
| 230 | + |
| 231 | + if (avroSchema == null) { |
| 232 | + throw new IllegalStateException("Avro header has not been read yet"); |
| 233 | + } |
| 234 | + |
| 235 | + if (!hasNextBatch()) { |
| 236 | + return false; |
| 237 | + } |
| 238 | + |
| 239 | + // Read Avro block from the main encoder |
| 240 | + long nRows = decoder.readLong(); |
| 241 | + batchBuffer = decoder.readBytes(batchBuffer); |
| 242 | + decoder.readFixed(batchSyncMarker); |
| 243 | + |
| 244 | + // Validate sync marker - mismatch indicates a corrupt file |
| 245 | + long batchSize = |
| 246 | + zigzagSize(nRows) + |
| 247 | + zigzagSize(batchBuffer.remaining()) + |
| 248 | + batchBuffer.remaining() + |
| 249 | + SYNC_MARKER_SIZE; |
| 250 | + |
| 251 | + int validateMarker = BinaryData.compareBytes( |
| 252 | + syncMarker, 0, SYNC_MARKER_SIZE, |
| 253 | + batchSyncMarker, 0, SYNC_MARKER_SIZE); |
| 254 | + |
| 255 | + if (validateMarker != 0) { |
| 256 | + throw new RuntimeException("Invalid AVRO data file: The file is corrupted"); |
| 257 | + } |
| 258 | + |
| 259 | + // Reset producers |
| 260 | + recordConsumer.getConsumers().forEach(consumer -> ensureCapacity(consumer.getVector(), (int) nRows)); |
| 261 | + recordConsumer.getConsumers().forEach(consumer -> consumer.setPosition(0)); |
| 262 | + |
| 263 | + // Decompress the batch buffer using Avro's codecs |
| 264 | + var codec = AvroCompression.getAvroCodec(avroCodec); |
| 265 | + var batchDecompressed = codec.decompress(batchBuffer); |
| 266 | + |
| 267 | + // Prepare batch stream and decoder |
| 268 | + try (InputStream batchStream = new ByteArrayInputStream(batchDecompressed.array())) { |
| 269 | + |
| 270 | + batchDecoder = DecoderFactory.get().directBinaryDecoder(batchStream, batchDecoder); |
| 271 | + |
| 272 | + // Consume a batch, reading from the batch stream (buffer) |
| 273 | + for (int row = 0; row < nRows; row++) { |
| 274 | + recordConsumer.consume(batchDecoder); |
| 275 | + } |
| 276 | + |
| 277 | + arrowBatch.setRowCount((int) nRows); |
| 278 | + |
| 279 | + // Update next batch position |
| 280 | + nextBatchPosition += batchSize; |
| 281 | + |
| 282 | + // Batch is ready |
| 283 | + return true; |
| 284 | + } |
| 285 | + } |
| 286 | + |
| 287 | + private void ensureCapacity(FieldVector vector, int capacity) { |
| 288 | + if (vector.getValueCapacity() < capacity) { |
| 289 | + vector.setInitialCapacity(capacity); |
| 290 | + } |
| 291 | + } |
| 292 | + |
| 293 | + // Check for position and size of the next Avro data block |
| 294 | + // Provides a mechanism for non-blocking / reactive styles |
| 295 | + boolean hasNextBatch() throws IOException { |
| 296 | + |
| 297 | + if (avroSchema == null) { |
| 298 | + throw new IllegalStateException("Avro header has not been read yet"); |
| 299 | + } |
| 300 | + |
| 301 | + if (blocking) { |
| 302 | + return ! decoder.isEnd(); |
| 303 | + } |
| 304 | + |
| 305 | + var in = decoder.inputStream(); |
| 306 | + in.mark(1); |
| 307 | + |
| 308 | + try { |
| 309 | + |
| 310 | + int nextByte = in.read(); |
| 311 | + in.reset(); |
| 312 | + |
| 313 | + return nextByte >= 0; |
| 314 | + } |
| 315 | + catch(EOFException e) { |
| 316 | + return false; |
| 317 | + } |
| 318 | + } |
| 319 | + |
| 320 | + long nextBatchPosition() { |
| 321 | + |
| 322 | + if (avroSchema == null) { |
| 323 | + throw new IllegalStateException("Avro header has not been read yet"); |
| 324 | + } |
| 325 | + |
| 326 | + return nextBatchPosition; |
| 327 | + } |
| 328 | + |
| 329 | + public long nextBatchSize() throws IOException { |
| 330 | + |
| 331 | + if (avroSchema == null) { |
| 332 | + throw new IllegalStateException("Avro header has not been read yet"); |
| 333 | + } |
| 334 | + |
| 335 | + if (blocking) { |
| 336 | + throw new IllegalStateException("Next batch size is only available in non-blocking mode"); |
| 337 | + } |
| 338 | + |
| 339 | + InputStream in = decoder.inputStream(); |
| 340 | + in.mark(20); |
| 341 | + |
| 342 | + long nRows = decoder.readLong(); |
| 343 | + long nBytes = decoder.readLong(); |
| 344 | + |
| 345 | + in.reset(); |
| 346 | + |
| 347 | + return zigzagSize(nRows) + zigzagSize(nBytes) + nBytes + SYNC_MARKER_SIZE; |
| 348 | + } |
| 349 | + |
| 350 | + private int zigzagSize(long n) { |
| 351 | + |
| 352 | + long val = (n << 1) ^ (n >> 63); // move sign to low-order bit |
| 353 | + int bytes = 1; |
| 354 | + |
| 355 | + while ((val & ~0x7F) != 0) { |
| 356 | + bytes += 1; |
| 357 | + val >>>= 7; |
| 358 | + } |
| 359 | + |
| 360 | + return bytes; |
| 361 | + } |
| 362 | + |
| 363 | + // Closes encoder and / or channel |
| 364 | + // Also closes VSR and dictionary vectors |
| 365 | + void close() throws IOException { |
| 366 | + |
| 367 | + stream.close(); |
| 368 | + |
| 369 | + if (arrowBatch != null) { |
| 370 | + arrowBatch.close(); |
| 371 | + } |
| 372 | + |
| 373 | + if (dictionaries != null) { |
| 374 | + for (long dictionaryId : dictionaries.getDictionaryIds()) { |
| 375 | + Dictionary dictionary = dictionaries.lookup(dictionaryId); |
| 376 | + dictionary.getVector().close(); |
| 377 | + } |
| 378 | + } |
| 379 | + } |
| 380 | +} |
0 commit comments