17
17
#include < algorithm>
18
18
#include < bitset>
19
19
#include < map>
20
+ #include < memory>
20
21
#include < random>
21
22
#include < string>
22
23
#include < vector>
27
28
28
29
namespace protobuf_mutator {
29
30
31
+ using protobuf::Any;
30
32
using protobuf::Descriptor;
31
33
using protobuf::FieldDescriptor;
32
34
using protobuf::FileDescriptor;
@@ -360,15 +362,76 @@ class DataSourceSampler {
360
362
WeightedReservoirSampler<ConstFieldInstance, RandomEngine> sampler_;
361
363
};
362
364
365
+ using UnpackedAny =
366
+ std::unordered_map<const Message*, std::unique_ptr<Message>>;
367
+
368
+ const Descriptor* GetAnyTypeDescriptor (const Any& any) {
369
+ std::string type_name;
370
+ if (!Any::ParseAnyTypeUrl (any.type_url (), &type_name)) return nullptr ;
371
+ return any.descriptor ()->file ()->pool ()->FindMessageTypeByName (type_name);
372
+ }
373
+
374
+ std::unique_ptr<Message> UnpackAny (const Any& any) {
375
+ const Descriptor* desc = GetAnyTypeDescriptor (any);
376
+ if (!desc) return {};
377
+ std::unique_ptr<Message> message (
378
+ any.GetReflection ()->GetMessageFactory ()->GetPrototype (desc)->New ());
379
+ message->ParsePartialFromString (any.value ());
380
+ return message;
381
+ }
382
+
383
+ const Any* CastToAny (const Message* message) {
384
+ return Any::GetDescriptor () == message->GetDescriptor ()
385
+ ? static_cast <const Any*>(message)
386
+ : nullptr ;
387
+ }
388
+
389
+ Any* CastToAny (Message* message) {
390
+ return Any::GetDescriptor () == message->GetDescriptor ()
391
+ ? static_cast <Any*>(message)
392
+ : nullptr ;
393
+ }
394
+
395
+ std::unique_ptr<Message> UnpackIfAny (const Message& message) {
396
+ if (const Any* any = CastToAny (&message)) return UnpackAny (*any);
397
+ return {};
398
+ }
399
+
400
+ void UnpackAny (const Message& message, UnpackedAny* result) {
401
+ if (std::unique_ptr<Message> any = UnpackIfAny (message)) {
402
+ UnpackAny (*any, result);
403
+ result->emplace (&message, std::move (any));
404
+ return ;
405
+ }
406
+
407
+ const Descriptor* descriptor = message.GetDescriptor ();
408
+ const Reflection* reflection = message.GetReflection ();
409
+
410
+ for (int i = 0 ; i < descriptor->field_count (); ++i) {
411
+ const FieldDescriptor* field = descriptor->field (i);
412
+ if (field->cpp_type () == FieldDescriptor::CPPTYPE_MESSAGE) {
413
+ if (field->is_repeated ()) {
414
+ const int field_size = reflection->FieldSize (message, field);
415
+ for (int j = 0 ; j < field_size; ++j) {
416
+ UnpackAny (reflection->GetRepeatedMessage (message, field, j), result);
417
+ }
418
+ } else if (reflection->HasField (message, field)) {
419
+ UnpackAny (reflection->GetMessage (message, field), result);
420
+ }
421
+ }
422
+ }
423
+ }
424
+
363
425
class PostProcessing {
364
426
public:
365
427
using PostProcessors =
366
428
std::unordered_multimap<const Descriptor*, Mutator::PostProcess>;
367
429
368
430
PostProcessing (bool keep_initialized, const PostProcessors& post_processors,
369
- RandomEngine* random)
431
+ UnpackedAny& any, RandomEngine* random)
370
432
: keep_initialized_(keep_initialized),
371
433
post_processors_ (post_processors),
434
+ any_(any),
372
435
random_(random) {}
373
436
374
437
void Run (Message* message, int max_depth) {
@@ -410,6 +473,22 @@ class PostProcessing {
410
473
}
411
474
}
412
475
476
+ if (Any* any = CastToAny (message)) {
477
+ if (max_depth < 0 ) {
478
+ // Clear deep Any fields to avoid stack overflow.
479
+ any->Clear ();
480
+ } else {
481
+ auto It = any_.find (message);
482
+ if (It != any_.end ()) {
483
+ Run (It->second .get (), max_depth);
484
+ // assert(GetAnyTypeDescriptor(*any) == It->second->GetDescriptor());
485
+ // if (GetAnyTypeDescriptor(*any) != It->second->GetDescriptor()) {}
486
+ It->second ->SerializePartialToString (any->mutable_value ());
487
+ }
488
+ }
489
+ }
490
+
491
+ // Call user callback after message trimmed, initialized and packed.
413
492
auto range = post_processors_.equal_range (descriptor);
414
493
for (auto it = range.first ; it != range.second ; ++it)
415
494
it->second (message, (*random_)());
@@ -418,6 +497,7 @@ class PostProcessing {
418
497
private:
419
498
bool keep_initialized_;
420
499
const PostProcessors& post_processors_;
500
+ UnpackedAny& any_;
421
501
RandomEngine* random_;
422
502
};
423
503
@@ -543,30 +623,47 @@ struct CreateField : public FieldFunction<CreateField> {
543
623
void Mutator::Seed (uint32_t value) { random_.seed (value); }
544
624
545
625
void Mutator::Mutate (Message* message, size_t max_size_hint) {
626
+ UnpackedAny any;
627
+ UnpackAny (*message, &any);
628
+
546
629
Messages messages;
630
+ messages.reserve (any.size () + 1 );
547
631
messages.push_back (message);
632
+ for (const auto & kv : any) messages.push_back (kv.second .get ());
633
+
548
634
ConstMessages sources (messages.begin (), messages.end ());
549
635
MutateImpl (sources, messages, false ,
550
636
static_cast <int >(max_size_hint) -
551
637
static_cast <int >(message->ByteSizeLong ()));
552
638
553
- PostProcessing (keep_initialized_, post_processors_, &random_)
639
+ PostProcessing (keep_initialized_, post_processors_, any, &random_)
554
640
.Run (message, kMaxInitializeDepth );
555
641
assert (IsInitialized (*message));
556
642
}
557
643
558
644
void Mutator::CrossOver (const Message& message1, Message* message2,
559
645
size_t max_size_hint) {
646
+ UnpackedAny any;
647
+ UnpackAny (*message2, &any);
648
+
560
649
Messages messages;
650
+ messages.reserve (any.size () + 1 );
561
651
messages.push_back (message2);
652
+ for (auto & kv : any) messages.push_back (kv.second .get ());
653
+
654
+ UnpackAny (message1, &any);
655
+
562
656
ConstMessages sources;
657
+ sources.reserve (any.size () + 2 );
563
658
sources.push_back (&message1);
564
659
sources.push_back (message2);
565
- int size_increase_hint = static_cast <int >(max_size_hint) -
566
- static_cast <int >(message2->ByteSizeLong ());
567
- MutateImpl (sources, messages, true , size_increase_hint);
660
+ for (const auto & kv : any) sources.push_back (kv.second .get ());
661
+
662
+ MutateImpl (sources, messages, true ,
663
+ static_cast <int >(max_size_hint) -
664
+ static_cast <int >(message2->ByteSizeLong ()));
568
665
569
- PostProcessing (keep_initialized_, post_processors_, &random_)
666
+ PostProcessing (keep_initialized_, post_processors_, any, &random_)
570
667
.Run (message2, kMaxInitializeDepth );
571
668
assert (IsInitialized (*message2));
572
669
}
0 commit comments