@@ -355,6 +355,67 @@ class DataSourceSampler {
355
355
WeightedReservoirSampler<ConstFieldInstance, RandomEngine> sampler_;
356
356
};
357
357
358
+ class PostProcessing {
359
+ public:
360
+ using PostProcessors = std::unordered_multimap<const protobuf::Descriptor*,
361
+ Mutator::PostProcess>;
362
+
363
+ PostProcessing (bool keep_initialized, const PostProcessors& post_processors,
364
+ RandomEngine* random)
365
+ : keep_initialized_(keep_initialized),
366
+ post_processors_ (post_processors),
367
+ random_(random) {}
368
+
369
+ void Run (Message* message, int max_depth) {
370
+ --max_depth;
371
+ const Descriptor* descriptor = message->GetDescriptor ();
372
+
373
+ // Apply custom mutators in nested messages before packing any.
374
+ const Reflection* reflection = message->GetReflection ();
375
+ for (int i = 0 ; i < descriptor->field_count (); i++) {
376
+ const FieldDescriptor* field = descriptor->field (i);
377
+ if (keep_initialized_ &&
378
+ (field->is_required () || descriptor->options ().map_entry ()) &&
379
+ !reflection->HasField (*message, field)) {
380
+ CreateDefaultField ()(FieldInstance (message, field));
381
+ }
382
+
383
+ if (field->cpp_type () != FieldDescriptor::CPPTYPE_MESSAGE) continue ;
384
+
385
+ if (max_depth < 0 && !field->is_required ()) {
386
+ // Clear deep optional fields to avoid stack overflow.
387
+ reflection->ClearField (message, field);
388
+ if (field->is_repeated ())
389
+ assert (!reflection->FieldSize (*message, field));
390
+ else
391
+ assert (!reflection->HasField (*message, field));
392
+ continue ;
393
+ }
394
+
395
+ if (field->is_repeated ()) {
396
+ const int field_size = reflection->FieldSize (*message, field);
397
+ for (int j = 0 ; j < field_size; ++j) {
398
+ Message* nested_message =
399
+ reflection->MutableRepeatedMessage (message, field, j);
400
+ Run (nested_message, max_depth);
401
+ }
402
+ } else if (reflection->HasField (*message, field)) {
403
+ Message* nested_message = reflection->MutableMessage (message, field);
404
+ Run (nested_message, max_depth);
405
+ }
406
+ }
407
+
408
+ auto range = post_processors_.equal_range (descriptor);
409
+ for (auto it = range.first ; it != range.second ; ++it)
410
+ it->second (message, (*random_)());
411
+ }
412
+
413
+ private:
414
+ bool keep_initialized_;
415
+ const PostProcessors& post_processors_;
416
+ RandomEngine* random_;
417
+ };
418
+
358
419
} // namespace
359
420
360
421
class FieldMutator {
@@ -479,47 +540,16 @@ void Mutator::Mutate(Message* message, size_t max_size_hint) {
479
540
static_cast <int >(max_size_hint) -
480
541
static_cast <int >(message->ByteSizeLong ()));
481
542
482
- InitializeAndTrim (message, kMaxInitializeDepth );
543
+ PostProcessing (keep_initialized_, post_processors_, &random_)
544
+ .Run (message, kMaxInitializeDepth );
483
545
assert (IsInitialized (*message));
484
-
485
- if (!post_processors_.empty ()) {
486
- ApplyPostProcessing (message);
487
- }
488
546
}
489
547
490
548
void Mutator::RegisterPostProcessor (const Descriptor* desc,
491
549
PostProcess callback) {
492
550
post_processors_.emplace (desc, callback);
493
551
}
494
552
495
- void Mutator::ApplyPostProcessing (Message* message) {
496
- const Descriptor* descriptor = message->GetDescriptor ();
497
-
498
- auto range = post_processors_.equal_range (descriptor);
499
- for (auto it = range.first ; it != range.second ; ++it)
500
- it->second (message, random_ ());
501
-
502
- // Now recursively apply custom mutators.
503
- const Reflection* reflection = message->GetReflection ();
504
- for (int i = 0 ; i < descriptor->field_count (); i++) {
505
- const FieldDescriptor* field = descriptor->field (i);
506
- if (field->cpp_type () != FieldDescriptor::CPPTYPE_MESSAGE) {
507
- continue ;
508
- }
509
- if (field->is_repeated ()) {
510
- const int field_size = reflection->FieldSize (*message, field);
511
- for (int j = 0 ; j < field_size; ++j) {
512
- Message* nested_message =
513
- reflection->MutableRepeatedMessage (message, field, j);
514
- ApplyPostProcessing (nested_message);
515
- }
516
- } else if (reflection->HasField (*message, field)) {
517
- Message* nested_message = reflection->MutableMessage (message, field);
518
- ApplyPostProcessing (nested_message);
519
- }
520
- }
521
- }
522
-
523
553
bool Mutator::MutateImpl (const Message& source, Message* message,
524
554
bool copy_clone_only, int size_increase_hint) {
525
555
if (size_increase_hint > 0 ) size_increase_hint /= 2 ;
@@ -578,49 +608,9 @@ void Mutator::CrossOver(const Message& message1, Message* message2,
578
608
MutateImpl (message1, message2, true , size_increase_hint) ||
579
609
MutateImpl (*message2, message2, true , size_increase_hint);
580
610
581
- InitializeAndTrim (message2, kMaxInitializeDepth );
611
+ PostProcessing (keep_initialized_, post_processors_, &random_)
612
+ .Run (message2, kMaxInitializeDepth );
582
613
assert (IsInitialized (*message2));
583
-
584
- if (!post_processors_.empty ()) {
585
- ApplyPostProcessing (message2);
586
- }
587
- }
588
-
589
- void Mutator::InitializeAndTrim (Message* message, int max_depth) {
590
- const Descriptor* descriptor = message->GetDescriptor ();
591
- const Reflection* reflection = message->GetReflection ();
592
- for (int i = 0 ; i < descriptor->field_count (); ++i) {
593
- const FieldDescriptor* field = descriptor->field (i);
594
- if (keep_initialized_ &&
595
- (field->is_required () || descriptor->options ().map_entry ()) &&
596
- !reflection->HasField (*message, field)) {
597
- CreateDefaultField ()(FieldInstance (message, field));
598
- }
599
-
600
- if (field->cpp_type () == FieldDescriptor::CPPTYPE_MESSAGE) {
601
- if (max_depth <= 0 && !field->is_required ()) {
602
- // Clear deep optional fields to avoid stack overflow.
603
- reflection->ClearField (message, field);
604
- if (field->is_repeated ())
605
- assert (!reflection->FieldSize (*message, field));
606
- else
607
- assert (!reflection->HasField (*message, field));
608
- continue ;
609
- }
610
-
611
- if (field->is_repeated ()) {
612
- const int field_size = reflection->FieldSize (*message, field);
613
- for (int j = 0 ; j < field_size; ++j) {
614
- Message* nested_message =
615
- reflection->MutableRepeatedMessage (message, field, j);
616
- InitializeAndTrim (nested_message, max_depth - 1 );
617
- }
618
- } else if (reflection->HasField (*message, field)) {
619
- Message* nested_message = reflection->MutableMessage (message, field);
620
- InitializeAndTrim (nested_message, max_depth - 1 );
621
- }
622
- }
623
- }
624
614
}
625
615
626
616
int32_t Mutator::MutateInt32 (int32_t value) { return FlipBit (value, &random_); }
0 commit comments