|
14 | 14 |
|
15 | 15 | #include "paddle/fluid/recordio/chunk.h"
|
16 | 16 |
|
17 |
| -#include <cstring> |
| 17 | +#include <memory> |
18 | 18 | #include <sstream>
|
19 |
| -#include <utility> |
20 |
| - |
21 |
| -#include "snappy.h" |
22 |
| - |
23 |
| -#include "paddle/fluid/recordio/crc32.h" |
| 19 | +#include "paddle/fluid/platform/enforce.h" |
| 20 | +#include "snappystream.hpp" |
| 21 | +#include "zlib.h" |
24 | 22 |
|
25 | 23 | namespace paddle {
|
26 | 24 | namespace recordio {
|
| 25 | +constexpr size_t kMaxBufSize = 1024; |
27 | 26 |
|
28 |
| -void Chunk::Add(const char* record, size_t length) { |
29 |
| - records_.emplace_after(std::string(record, length)); |
30 |
| - num_bytes_ += s.size() * sizeof(char); |
| 27 | +template <typename Callback> |
| 28 | +static void ReadStreamByBuf(std::istream& in, int limit, Callback callback) { |
| 29 | + char buf[kMaxBufSize]; |
| 30 | + std::streamsize actual_size; |
| 31 | + size_t counter = 0; |
| 32 | + do { |
| 33 | + auto actual_max = |
| 34 | + limit > 0 ? std::min(limit - counter, kMaxBufSize) : kMaxBufSize; |
| 35 | + actual_size = in.readsome(buf, actual_max); |
| 36 | + if (actual_size == 0) { |
| 37 | + break; |
| 38 | + } |
| 39 | + callback(buf, actual_size); |
| 40 | + if (limit > 0) { |
| 41 | + counter += actual_size; |
| 42 | + } |
| 43 | + } while (actual_size == kMaxBufSize); |
31 | 44 | }
|
32 | 45 |
|
33 |
| -bool Chunk::Dump(Stream* fo, Compressor ct) { |
| 46 | +static void PipeStream(std::istream& in, std::ostream& os) { |
| 47 | + ReadStreamByBuf( |
| 48 | + in, -1, [&os](const char* buf, size_t len) { os.write(buf, len); }); |
| 49 | +} |
| 50 | +static uint32_t Crc32Stream(std::istream& in, int limit = -1) { |
| 51 | + auto crc = crc32(0, nullptr, 0); |
| 52 | + ReadStreamByBuf(in, limit, [&crc](const char* buf, size_t len) { |
| 53 | + crc = crc32(crc, reinterpret_cast<const Bytef*>(buf), len); |
| 54 | + }); |
| 55 | + return crc; |
| 56 | +} |
| 57 | + |
| 58 | +bool Chunk::Write(std::ostream& os, Compressor ct) const { |
34 | 59 | // NOTE(dzhwinter): don't check records.numBytes instead, because
|
35 | 60 | // empty records are allowed.
|
36 |
| - if (records_.size() == 0) return false; |
| 61 | + if (records_.empty()) { |
| 62 | + return false; |
| 63 | + } |
| 64 | + std::stringstream sout; |
| 65 | + std::unique_ptr<std::ostream> compressed_stream; |
| 66 | + switch (ct) { |
| 67 | + case Compressor::kNoCompress: |
| 68 | + break; |
| 69 | + case Compressor::kSnappy: |
| 70 | + compressed_stream.reset(new snappy::oSnappyStream(sout)); |
| 71 | + break; |
| 72 | + default: |
| 73 | + PADDLE_THROW("Not implemented"); |
| 74 | + } |
| 75 | + |
| 76 | + std::ostream& buf_stream = compressed_stream ? *compressed_stream : sout; |
37 | 77 |
|
38 |
| - // pack the record into consecutive memory for compress |
39 |
| - std::ostringstream os; |
40 | 78 | for (auto& record : records_) {
|
41 |
| - os.write(record.size(), sizeof(size_t)); |
42 |
| - os.write(record.data(), static_cast<std::streamsize>(record.size())); |
| 79 | + size_t sz = record.size(); |
| 80 | + buf_stream.write(reinterpret_cast<const char*>(&sz), sizeof(uint32_t)) |
| 81 | + .write(record.data(), record.size()); |
43 | 82 | }
|
44 | 83 |
|
45 |
| - std::unique_ptr<char[]> buffer(new char[num_bytes_]); |
46 |
| - size_t compressed = |
47 |
| - CompressData(os.str().c_str(), num_bytes_, ct, buffer.get()); |
48 |
| - uint32_t checksum = Crc32(buffer.get(), compressed); |
49 |
| - Header hdr(records_.size(), checksum, ct, static_cast<uint32_t>(compressed)); |
50 |
| - hdr.Write(fo); |
51 |
| - fo.Write(buffer.get(), compressed); |
52 |
| - // clear the content |
53 |
| - records_.clear(); |
54 |
| - num_bytes_ = 0; |
| 84 | + if (compressed_stream) { |
| 85 | + compressed_stream.reset(); |
| 86 | + } |
| 87 | + |
| 88 | + auto end_pos = sout.tellg(); |
| 89 | + sout.seekg(0, std::ios::beg); |
| 90 | + uint32_t len = static_cast<uint32_t>(end_pos - sout.tellg()); |
| 91 | + uint32_t crc = Crc32Stream(sout); |
| 92 | + sout.seekg(0, std::ios::beg); |
| 93 | + |
| 94 | + Header hdr(static_cast<uint32_t>(records_.size()), crc, ct, len); |
| 95 | + hdr.Write(os); |
| 96 | + PipeStream(sout, os); |
55 | 97 | return true;
|
56 | 98 | }
|
57 | 99 |
|
58 |
| -void Chunk::Parse(Stream* fi, size_t offset) { |
59 |
| - fi->Seek(offset); |
| 100 | +void Chunk::Parse(std::istream& sin) { |
60 | 101 | Header hdr;
|
61 |
| - hdr.Parse(fi); |
62 |
| - |
63 |
| - size_t size = static_cast<size_t>(hdr.CompressSize()); |
64 |
| - std::unique_ptr<char[]> buffer(new char[size]); |
65 |
| - fi->Read(buffer.get(), size); |
66 |
| - size_t deflated_size = 0; |
67 |
| - snappy::GetUncompressedLength(buffer.get(), size, &deflated_size); |
68 |
| - std::unique_ptr<char[]> deflated_buffer(new char[deflated_size]); |
69 |
| - DeflateData(buffer.get(), size, hdr.CompressType(), deflated_buffer.get()); |
70 |
| - std::istringstream deflated( |
71 |
| - std::string(deflated_buffer.get(), deflated_size)); |
72 |
| - for (size_t i = 0; i < hdr.NumRecords(); ++i) { |
73 |
| - size_t rs; |
74 |
| - deflated.read(&rs, sizeof(size_t)); |
75 |
| - std::string record(rs, '\0'); |
76 |
| - deflated.read(&record[0], rs); |
77 |
| - records_.emplace_back(record); |
78 |
| - num_bytes_ += record.size(); |
79 |
| - } |
80 |
| -} |
| 102 | + hdr.Parse(sin); |
| 103 | + auto beg_pos = sin.tellg(); |
| 104 | + auto crc = Crc32Stream(sin, hdr.CompressSize()); |
| 105 | + PADDLE_ENFORCE_EQ(hdr.Checksum(), crc); |
81 | 106 |
|
82 |
| -size_t CompressData(const char* in, |
83 |
| - size_t in_length, |
84 |
| - Compressor ct, |
85 |
| - char* out) { |
86 |
| - size_t compressd_size = 0; |
87 |
| - switch (ct) { |
| 107 | + Clear(); |
| 108 | + |
| 109 | + sin.seekg(beg_pos, std::ios::beg); |
| 110 | + std::unique_ptr<std::istream> compressed_stream; |
| 111 | + switch (hdr.CompressType()) { |
88 | 112 | case Compressor::kNoCompress:
|
89 |
| - // do nothing |
90 |
| - memcpy(out, in, in_length); |
91 |
| - compressd_size = in_length; |
92 | 113 | break;
|
93 | 114 | case Compressor::kSnappy:
|
94 |
| - snappy::RawCompress(in, in_length, out, &compressd_size); |
| 115 | + compressed_stream.reset(new snappy::iSnappyStream(sin)); |
95 | 116 | break;
|
| 117 | + default: |
| 118 | + PADDLE_THROW("Not implemented"); |
96 | 119 | }
|
97 |
| - return compressd_size; |
98 |
| -} |
99 | 120 |
|
100 |
| -void DeflateData(const char* in, size_t in_length, Compressor ct, char* out) { |
101 |
| - switch (c) { |
102 |
| - case Compressor::kNoCompress: |
103 |
| - memcpy(out, in, in_length); |
104 |
| - break; |
105 |
| - case Compressor::kSnappy: |
106 |
| - snappy::RawUncompress(in, in_length, out); |
107 |
| - break; |
| 121 | + std::istream& stream = compressed_stream ? *compressed_stream : sin; |
| 122 | + |
| 123 | + for (uint32_t i = 0; i < hdr.NumRecords(); ++i) { |
| 124 | + uint32_t rec_len; |
| 125 | + stream.read(reinterpret_cast<char*>(&rec_len), sizeof(uint32_t)); |
| 126 | + std::string buf; |
| 127 | + buf.resize(rec_len); |
| 128 | + stream.read(&buf[0], rec_len); |
| 129 | + Add(buf); |
108 | 130 | }
|
109 | 131 | }
|
110 | 132 |
|
|
0 commit comments