Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions tflite/kernels/parse_example/example_proto_fast_parsing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ limitations under the License.
namespace tensorflow {
namespace example {

string ExampleName(const absl::Span<const tstring> example_names, int n) {
std::string ExampleName(const absl::Span<const tstring> example_names, int n) {
return example_names.empty() ? "<unknown>" : example_names[n];
}

Expand Down Expand Up @@ -62,19 +62,19 @@ void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src,
}
}

uint8 PeekTag(protobuf::io::CodedInputStream* stream) {
uint8_t PeekTag(protobuf::io::CodedInputStream* stream) {
DCHECK(stream != nullptr);
const void* ptr;
int size;
if (!stream->GetDirectBufferPointer(&ptr, &size)) return 0;
return *static_cast<const uint8*>(ptr);
return *static_cast<const uint8_t*>(ptr);
}

bool ParseString(protobuf::io::CodedInputStream* stream,
absl::string_view* result) {
DCHECK(stream != nullptr);
DCHECK(result != nullptr);
uint32 length;
uint32_t length;
if (!stream->ReadVarint32(&length)) return false;
if (length == 0) {
*result = absl::string_view(nullptr, 0);
Expand All @@ -85,7 +85,7 @@ bool ParseString(protobuf::io::CodedInputStream* stream,
if (!stream->GetDirectBufferPointer(&stream_alias, &stream_size)) {
return false;
}
if (static_cast<uint32>(stream_size) < length) return false;
if (static_cast<uint32_t>(stream_size) < length) return false;
*result = absl::string_view(static_cast<const char*>(stream_alias), length);
stream->Skip(length);
return true;
Expand All @@ -95,7 +95,7 @@ bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
parsed::FeatureMapEntry* feature_map_entry) {
DCHECK(stream != nullptr);
DCHECK(feature_map_entry != nullptr);
uint32 length;
uint32_t length;
if (!stream->ReadVarint32(&length)) return false;
auto limit = stream->PushLimit(length);
if (!stream->ExpectTag(kDelimitedTag(1))) return false;
Expand All @@ -113,7 +113,7 @@ bool ParseFeatures(protobuf::io::CodedInputStream* stream,
parsed::Example* example) {
DCHECK(stream != nullptr);
DCHECK(example != nullptr);
uint32 length;
uint32_t length;
if (!stream->ReadVarint32(&length)) return false;
auto limit = stream->PushLimit(length);
while (!stream->ExpectAtEnd()) {
Expand Down Expand Up @@ -146,7 +146,7 @@ bool ParseExample(protobuf::io::CodedInputStream* stream,
bool ParseExample(absl::string_view serialized, parsed::Example* example) {
DCHECK(example != nullptr);
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(serialized.data()), serialized.size());
reinterpret_cast<const uint8_t*>(serialized.data()), serialized.size());
EnableAliasing(&stream);
return ParseExample(&stream, example);
}
Expand Down
86 changes: 45 additions & 41 deletions tflite/kernels/parse_example/example_proto_fast_parsing.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
template <typename A>
void EnableAliasing(A&& a) {}

uint8 PeekTag(protobuf::io::CodedInputStream* stream);
uint8_t PeekTag(protobuf::io::CodedInputStream* stream);

constexpr uint8 kVarintTag(uint32 tag) { return (tag << 3) | 0; }
constexpr uint8 kDelimitedTag(uint32 tag) { return (tag << 3) | 2; }
constexpr uint8 kFixed32Tag(uint32 tag) { return (tag << 3) | 5; }
constexpr uint8_t kVarintTag(uint32_t tag) { return (tag << 3) | 0; }
constexpr uint8_t kDelimitedTag(uint32_t tag) { return (tag << 3) | 2; }
constexpr uint8_t kFixed32Tag(uint32_t tag) { return (tag << 3) | 5; }

namespace parsed {

Expand All @@ -121,7 +121,7 @@ class Feature {
*dtype = DT_INVALID;
return absl::OkStatus();
}
uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
uint8_t oneof_tag = static_cast<uint8_t>(*serialized_.data());
serialized_.remove_prefix(1);
switch (oneof_tag) {
case kDelimitedTag(1):
Expand All @@ -143,15 +143,16 @@ class Feature {

bool GetNumElementsInBytesList(int* num_elements) {
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
reinterpret_cast<const uint8_t*>(serialized_.data()),
serialized_.size());
EnableAliasing(&stream);
uint32 length = 0;
uint32_t length = 0;
if (!stream.ReadVarint32(&length)) return false;
auto limit = stream.PushLimit(length);
*num_elements = 0;
while (!stream.ExpectAtEnd()) {
if (!stream.ExpectTag(kDelimitedTag(1))) return false;
uint32 bytes_length = 0;
uint32_t bytes_length = 0;
if (!stream.ReadVarint32(&bytes_length)) return false;
if (!stream.Skip(bytes_length)) return false;
++*num_elements;
Expand All @@ -176,18 +177,19 @@ class Feature {
DCHECK(bytes_list != nullptr);

protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
reinterpret_cast<const uint8_t*>(serialized_.data()),
serialized_.size());

EnableAliasing(&stream);

uint32 length;
uint32_t length;
if (!stream.ReadVarint32(&length)) return false;
auto limit = stream.PushLimit(length);

while (!stream.ExpectAtEnd()) {
if (!stream.ExpectTag(kDelimitedTag(1))) return false;
// parse string
uint32 bytes_length;
uint32_t bytes_length;
if (!stream.ReadVarint32(&bytes_length)) return false;
tstring* bytes = construct_at_end(bytes_list);
if (bytes == nullptr) return false;
Expand All @@ -202,22 +204,23 @@ class Feature {
bool ParseFloatList(Result* float_list) {
DCHECK(float_list != nullptr);
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
reinterpret_cast<const uint8_t*>(serialized_.data()),
serialized_.size());
EnableAliasing(&stream);
uint32 length;
uint32_t length;
if (!stream.ReadVarint32(&length)) return false;
auto limit = stream.PushLimit(length);

if (!stream.ExpectAtEnd()) {
uint8 peek_tag = PeekTag(&stream);
uint8_t peek_tag = PeekTag(&stream);
if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
return false;
}

constexpr int32_t kNumFloatBytes = 4;
if (peek_tag == kDelimitedTag(1)) { // packed
if (!stream.ExpectTag(kDelimitedTag(1))) return false; // packed tag
uint32 packed_length;
uint32_t packed_length;
if (!stream.ReadVarint32(&packed_length)) return false;
auto packed_limit = stream.PushLimit(packed_length);

Expand All @@ -233,16 +236,16 @@ class Feature {
sizeof(typename Result::value_type) == kNumFloatBytes) {
// Calculate the length of the buffer available what can be less than
// what we requested in resize in case of a LimitedArraySlice.
const uint32 bytes_to_copy =
std::min(static_cast<uint32>((float_list->size() - initial_size) *
kNumFloatBytes),
packed_length);
const uint32_t bytes_to_copy = std::min(
static_cast<uint32_t>((float_list->size() - initial_size) *
kNumFloatBytes),
packed_length);
if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy))
return false;
} else {
int64_t index = initial_size;
while (!stream.ExpectAtEnd()) {
uint32 buffer32;
uint32_t buffer32;
if (!stream.ReadLittleEndian32(&buffer32)) return false;
if (index < float_list->size()) {
float_list->data()[index] = absl::bit_cast<float>(buffer32);
Expand All @@ -262,7 +265,7 @@ class Feature {
int64_t index = initial_size;
while (!stream.ExpectAtEnd()) {
if (!stream.ExpectTag(kFixed32Tag(1))) return false;
uint32 buffer32;
uint32_t buffer32;
if (!stream.ReadLittleEndian32(&buffer32)) return false;
float_list->data()[index] = absl::bit_cast<float>(buffer32);
++index;
Expand All @@ -278,20 +281,21 @@ class Feature {
bool ParseInt64List(Result* int64_list) {
DCHECK(int64_list != nullptr);
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
reinterpret_cast<const uint8_t*>(serialized_.data()),
serialized_.size());
EnableAliasing(&stream);
uint32 length;
uint32_t length;
if (!stream.ReadVarint32(&length)) return false;
auto limit = stream.PushLimit(length);

if (!stream.ExpectAtEnd()) {
uint8 peek_tag = PeekTag(&stream);
uint8_t peek_tag = PeekTag(&stream);
if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
return false;
}
if (peek_tag == kDelimitedTag(1)) { // packed
if (!stream.ExpectTag(kDelimitedTag(1))) return false; // packed tag
uint32 packed_length;
uint32_t packed_length;
if (!stream.ReadVarint32(&packed_length)) return false;
auto packed_limit = stream.PushLimit(packed_length);

Expand Down Expand Up @@ -327,7 +331,7 @@ using Example = std::vector<FeatureMapEntry>;
} // namespace parsed

inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) {
uint32 data;
uint32_t data;
protobuf_uint64 dummy;
switch (stream->ReadTag() & 0x7) {
case 0: // varint
Expand Down Expand Up @@ -387,10 +391,10 @@ struct SparseBuffer {
};

struct SeededHasher {
uint64 operator()(absl::string_view s) const {
uint64_t operator()(absl::string_view s) const {
return Hash64(s.data(), s.size(), seed);
}
uint64 seed{0xDECAFCAFFE};
uint64_t seed{0xDECAFCAFFE};
};

// Use this in the "default" clause of switch statements when dispatching
Expand Down Expand Up @@ -451,21 +455,21 @@ struct FeatureProtos {
// Map from feature name to FeatureProtos for that feature.
using FeatureProtosMap = absl::flat_hash_map<absl::string_view, FeatureProtos>;

string ExampleName(const absl::Span<const tstring> example_names, int n);
std::string ExampleName(const absl::Span<const tstring> example_names, int n);

// Return the number of bytes elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
tstring* out) {
int num_elements = 0;
uint32 length;
uint32_t length;
if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
return -1;
}
if (length > 0) {
auto limit = stream->PushLimit(length);
while (!stream->ExpectAtEnd()) {
uint32 bytes_length;
uint32_t bytes_length;
if (!stream->ExpectTag(kDelimitedTag(1)) ||
!stream->ReadVarint32(&bytes_length)) {
return -1;
Expand Down Expand Up @@ -503,22 +507,22 @@ inline void PadInt64Feature(int num_to_pad, int64_t* out) {
inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
float* out) {
int num_elements = 0;
uint32 length;
uint32_t length;
if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
return -1;
}
if (length > 0) {
auto limit = stream->PushLimit(length);
uint8 peek_tag = PeekTag(stream);
uint8_t peek_tag = PeekTag(stream);
if (peek_tag == kDelimitedTag(1)) { // packed
uint32 packed_length;
uint32_t packed_length;
if (!stream->ExpectTag(kDelimitedTag(1)) ||
!stream->ReadVarint32(&packed_length)) {
return -1;
}
auto packed_limit = stream->PushLimit(packed_length);
while (!stream->ExpectAtEnd()) {
uint32 buffer32;
uint32_t buffer32;
if (!stream->ReadLittleEndian32(&buffer32)) {
return -1;
}
Expand All @@ -530,7 +534,7 @@ inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
stream->PopLimit(packed_limit);
} else if (peek_tag == kFixed32Tag(1)) {
while (!stream->ExpectAtEnd()) {
uint32 buffer32;
uint32_t buffer32;
if (!stream->ExpectTag(kFixed32Tag(1)) ||
!stream->ReadLittleEndian32(&buffer32)) {
return -1;
Expand All @@ -554,15 +558,15 @@ inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
int64_t* out) {
int num_elements = 0;
uint32 length;
uint32_t length;
if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
return -1;
}
if (length > 0) {
auto limit = stream->PushLimit(length);
uint8 peek_tag = PeekTag(stream);
uint8_t peek_tag = PeekTag(stream);
if (peek_tag == kDelimitedTag(1)) { // packed
uint32 packed_length;
uint32_t packed_length;
if (!stream->ExpectTag(kDelimitedTag(1)) ||
!stream->ReadVarint32(&packed_length)) {
return -1;
Expand Down Expand Up @@ -646,7 +650,7 @@ inline int GetFeatureLength(DataType dtype,
}

inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) {
uint8 peek_tag = PeekTag(stream);
uint8_t peek_tag = PeekTag(stream);
switch (peek_tag) {
case kDelimitedTag(1):
return DT_STRING;
Expand Down Expand Up @@ -680,7 +684,7 @@ inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
default:
return false;
}
uint32 length;
uint32_t length;
return stream->ReadVarint32(&length) && length == 0;
}

Expand Down
Loading