Skip to content

Commit 045acda

Browse files
committed
Combine post processing and initialization
It's going to be useful for Any support
1 parent a268b67 commit 045acda

File tree

2 files changed

+68
-79
lines changed

2 files changed

+68
-79
lines changed

src/mutator.cc

Lines changed: 65 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,67 @@ class DataSourceSampler {
355355
WeightedReservoirSampler<ConstFieldInstance, RandomEngine> sampler_;
356356
};
357357

358+
class PostProcessing {
359+
public:
360+
using PostProcessors = std::unordered_multimap<const protobuf::Descriptor*,
361+
Mutator::PostProcess>;
362+
363+
PostProcessing(bool keep_initialized, const PostProcessors& post_processors,
364+
RandomEngine* random)
365+
: keep_initialized_(keep_initialized),
366+
post_processors_(post_processors),
367+
random_(random) {}
368+
369+
void Run(Message* message, int max_depth) {
370+
--max_depth;
371+
const Descriptor* descriptor = message->GetDescriptor();
372+
373+
// Apply custom mutators in nested messages before packing any.
374+
const Reflection* reflection = message->GetReflection();
375+
for (int i = 0; i < descriptor->field_count(); i++) {
376+
const FieldDescriptor* field = descriptor->field(i);
377+
if (keep_initialized_ &&
378+
(field->is_required() || descriptor->options().map_entry()) &&
379+
!reflection->HasField(*message, field)) {
380+
CreateDefaultField()(FieldInstance(message, field));
381+
}
382+
383+
if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) continue;
384+
385+
if (max_depth < 0 && !field->is_required()) {
386+
// Clear deep optional fields to avoid stack overflow.
387+
reflection->ClearField(message, field);
388+
if (field->is_repeated())
389+
assert(!reflection->FieldSize(*message, field));
390+
else
391+
assert(!reflection->HasField(*message, field));
392+
continue;
393+
}
394+
395+
if (field->is_repeated()) {
396+
const int field_size = reflection->FieldSize(*message, field);
397+
for (int j = 0; j < field_size; ++j) {
398+
Message* nested_message =
399+
reflection->MutableRepeatedMessage(message, field, j);
400+
Run(nested_message, max_depth);
401+
}
402+
} else if (reflection->HasField(*message, field)) {
403+
Message* nested_message = reflection->MutableMessage(message, field);
404+
Run(nested_message, max_depth);
405+
}
406+
}
407+
408+
auto range = post_processors_.equal_range(descriptor);
409+
for (auto it = range.first; it != range.second; ++it)
410+
it->second(message, (*random_)());
411+
}
412+
413+
private:
414+
bool keep_initialized_;
415+
const PostProcessors& post_processors_;
416+
RandomEngine* random_;
417+
};
418+
358419
} // namespace
359420

360421
class FieldMutator {
@@ -479,47 +540,16 @@ void Mutator::Mutate(Message* message, size_t max_size_hint) {
479540
static_cast<int>(max_size_hint) -
480541
static_cast<int>(message->ByteSizeLong()));
481542

482-
InitializeAndTrim(message, kMaxInitializeDepth);
543+
PostProcessing(keep_initialized_, post_processors_, &random_)
544+
.Run(message, kMaxInitializeDepth);
483545
assert(IsInitialized(*message));
484-
485-
if (!post_processors_.empty()) {
486-
ApplyPostProcessing(message);
487-
}
488546
}
489547

490548
void Mutator::RegisterPostProcessor(const Descriptor* desc,
491549
PostProcess callback) {
492550
post_processors_.emplace(desc, callback);
493551
}
494552

495-
void Mutator::ApplyPostProcessing(Message* message) {
496-
const Descriptor* descriptor = message->GetDescriptor();
497-
498-
auto range = post_processors_.equal_range(descriptor);
499-
for (auto it = range.first; it != range.second; ++it)
500-
it->second(message, random_());
501-
502-
// Now recursively apply custom mutators.
503-
const Reflection* reflection = message->GetReflection();
504-
for (int i = 0; i < descriptor->field_count(); i++) {
505-
const FieldDescriptor* field = descriptor->field(i);
506-
if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
507-
continue;
508-
}
509-
if (field->is_repeated()) {
510-
const int field_size = reflection->FieldSize(*message, field);
511-
for (int j = 0; j < field_size; ++j) {
512-
Message* nested_message =
513-
reflection->MutableRepeatedMessage(message, field, j);
514-
ApplyPostProcessing(nested_message);
515-
}
516-
} else if (reflection->HasField(*message, field)) {
517-
Message* nested_message = reflection->MutableMessage(message, field);
518-
ApplyPostProcessing(nested_message);
519-
}
520-
}
521-
}
522-
523553
bool Mutator::MutateImpl(const Message& source, Message* message,
524554
bool copy_clone_only, int size_increase_hint) {
525555
if (size_increase_hint > 0) size_increase_hint /= 2;
@@ -578,49 +608,9 @@ void Mutator::CrossOver(const Message& message1, Message* message2,
578608
MutateImpl(message1, message2, true, size_increase_hint) ||
579609
MutateImpl(*message2, message2, true, size_increase_hint);
580610

581-
InitializeAndTrim(message2, kMaxInitializeDepth);
611+
PostProcessing(keep_initialized_, post_processors_, &random_)
612+
.Run(message2, kMaxInitializeDepth);
582613
assert(IsInitialized(*message2));
583-
584-
if (!post_processors_.empty()) {
585-
ApplyPostProcessing(message2);
586-
}
587-
}
588-
589-
void Mutator::InitializeAndTrim(Message* message, int max_depth) {
590-
const Descriptor* descriptor = message->GetDescriptor();
591-
const Reflection* reflection = message->GetReflection();
592-
for (int i = 0; i < descriptor->field_count(); ++i) {
593-
const FieldDescriptor* field = descriptor->field(i);
594-
if (keep_initialized_ &&
595-
(field->is_required() || descriptor->options().map_entry()) &&
596-
!reflection->HasField(*message, field)) {
597-
CreateDefaultField()(FieldInstance(message, field));
598-
}
599-
600-
if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
601-
if (max_depth <= 0 && !field->is_required()) {
602-
// Clear deep optional fields to avoid stack overflow.
603-
reflection->ClearField(message, field);
604-
if (field->is_repeated())
605-
assert(!reflection->FieldSize(*message, field));
606-
else
607-
assert(!reflection->HasField(*message, field));
608-
continue;
609-
}
610-
611-
if (field->is_repeated()) {
612-
const int field_size = reflection->FieldSize(*message, field);
613-
for (int j = 0; j < field_size; ++j) {
614-
Message* nested_message =
615-
reflection->MutableRepeatedMessage(message, field, j);
616-
InitializeAndTrim(nested_message, max_depth - 1);
617-
}
618-
} else if (reflection->HasField(*message, field)) {
619-
Message* nested_message = reflection->MutableMessage(message, field);
620-
InitializeAndTrim(nested_message, max_depth - 1);
621-
}
622-
}
623-
}
624614
}
625615

626616
int32_t Mutator::MutateInt32(int32_t value) { return FlipBit(value, &random_); }

src/mutator.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,18 +89,17 @@ class Mutator {
8989
private:
9090
friend class FieldMutator;
9191
friend class TestMutator;
92-
void InitializeAndTrim(protobuf::Message* message, int max_depth);
9392
bool MutateImpl(const protobuf::Message& source, protobuf::Message* message,
9493
bool copy_clone_only, int size_increase_hint);
9594
std::string MutateUtf8String(const std::string& value,
9695
int size_increase_hint);
97-
void ApplyPostProcessing(protobuf::Message* message);
9896
bool IsInitialized(const protobuf::Message& message) const;
9997
bool keep_initialized_ = true;
10098
size_t random_to_default_ratio_ = 100;
10199
RandomEngine random_;
102-
std::unordered_multimap<const protobuf::Descriptor*, PostProcess>
103-
post_processors_;
100+
using PostProcessors =
101+
std::unordered_multimap<const protobuf::Descriptor*, PostProcess>;
102+
PostProcessors post_processors_;
104103
};
105104

106105
} // namespace protobuf_mutator

0 commit comments

Comments
 (0)