Skip to content

Commit 2ad36ee

Browse files
zoddicusDawn LUCI CQ
authored andcommitted
[tint][ir][fuzz] Prevent encoding a binary that won't decode
Adds plumbing to get errors out of IR binary encoding that are not ICEs. This is then used in the roundtrip fuzzer to reject inputs that won't decode correctly due to internal limits. Fixes: 375220551 Change-Id: I775c150f867124b3e30cc2161645134e88b2c625 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/212314 Commit-Queue: Ryan Harrison <[email protected]> Auto-Submit: Ryan Harrison <[email protected]> Reviewed-by: dan sinclair <[email protected]> Commit-Queue: dan sinclair <[email protected]>
1 parent 111d568 commit 2ad36ee

File tree

4 files changed

+46
-9
lines changed

4 files changed

+46
-9
lines changed

src/tint/cmd/fuzz/ir/as/main.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,11 @@ tint::Result<tint::cmd::fuzz::ir::pb::Root> GenerateFuzzCaseProto(const tint::Pr
199199
tint::cmd::fuzz::ir::pb::Root fuzz_pb;
200200
{
201201
auto ir_pb = tint::core::ir::binary::EncodeToProto(module.Get());
202-
fuzz_pb.set_allocated_module(ir_pb.release());
202+
if (ir_pb != tint::Success) {
203+
std::cerr << " Failed to encode IR to proto: " << ir_pb.Failure() << "\n";
204+
return tint::Failure();
205+
}
206+
fuzz_pb.set_allocated_module(ir_pb.Get().release());
203207
}
204208

205209
return std::move(fuzz_pb);

src/tint/lang/core/ir/binary/encode.cc

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
#include "src/tint/lang/core/type/storage_texture.h"
8484
#include "src/tint/lang/core/type/u32.h"
8585
#include "src/tint/lang/core/type/void.h"
86+
#include "src/tint/utils/constants/internal_limits.h"
8687
#include "src/tint/utils/macros/compiler.h"
8788
#include "src/tint/utils/rtti/switch.h"
8889

@@ -95,13 +96,16 @@ namespace {
9596
struct Encoder {
9697
const Module& mod_in_;
9798
pb::Module& mod_out_;
99+
98100
Hashmap<const core::ir::Function*, uint32_t, 32> functions_{};
99101
Hashmap<const core::ir::Block*, uint32_t, 32> blocks_{};
100102
Hashmap<const core::type::Type*, uint32_t, 32> types_{};
101103
Hashmap<const core::ir::Value*, uint32_t, 32> values_{};
102104
Hashmap<const core::constant::Value*, uint32_t, 32> constant_values_{};
103105

104-
void Encode() {
106+
diag::List diags_{};
107+
108+
Result<SuccessType> Encode() {
105109
// Encode all user-declared structures first. This is to ensure that the IR disassembly
106110
// (which prints structure types first) does not reorder after encoding and decoding.
107111
for (auto* ty : mod_in_.Types()) {
@@ -119,8 +123,16 @@ struct Encoder {
119123
PopulateFunction(fns_out[i], mod_in_.functions[i]);
120124
}
121125
mod_out_.set_root_block(Block(mod_in_.root_block));
126+
127+
if (diags_.ContainsErrors()) {
128+
return Failure{std::move(diags_)};
129+
}
130+
return Success;
122131
}
123132

133+
/// Adds a new error to the diagnostics and returns a reference to it
134+
diag::Diagnostic& Error() { return diags_.AddError(Source{}); }
135+
124136
////////////////////////////////////////////////////////////////////////////
125137
// Functions
126138
////////////////////////////////////////////////////////////////////////////
@@ -477,7 +489,13 @@ struct Encoder {
477489
array_out.set_stride(array_in->Stride());
478490
tint::Switch(
479491
array_in->Count(), //
480-
[&](const core::type::ConstantArrayCount* c) { array_out.set_count(c->value); },
492+
[&](const core::type::ConstantArrayCount* c) {
493+
array_out.set_count(c->value);
494+
if (c->value >= internal_limits::kMaxArrayElementCount) {
495+
Error() << "array count (" << c->value << ") must be less than "
496+
<< internal_limits::kMaxArrayElementCount;
497+
}
498+
},
481499
[&](const core::type::RuntimeArrayCount*) { array_out.set_count(0); },
482500
TINT_ICE_ON_NO_MATCH);
483501
}
@@ -647,6 +665,10 @@ struct Encoder {
647665
void ConstantValueSplat(pb::ConstantValueSplat& splat_out,
648666
const core::constant::Splat* splat_in) {
649667
splat_out.set_type(Type(splat_in->type));
668+
if (DAWN_UNLIKELY(splat_in->count > internal_limits::kMaxArrayConstructorElements)) {
669+
Error() << "array constructor has excessive number of elements (>"
670+
<< internal_limits::kMaxArrayConstructorElements << ")";
671+
}
650672
splat_out.set_elements(ConstantValue(splat_in->el));
651673
splat_out.set_count(static_cast<uint32_t>(splat_in->count));
652674
}
@@ -1220,23 +1242,29 @@ struct Encoder {
12201242

12211243
} // namespace
12221244

1223-
std::unique_ptr<pb::Module> EncodeToProto(const Module& mod_in) {
1245+
Result<std::unique_ptr<pb::Module>> EncodeToProto(const Module& mod_in) {
12241246
GOOGLE_PROTOBUF_VERIFY_VERSION;
12251247

12261248
pb::Module mod_out;
1227-
Encoder{mod_in, mod_out}.Encode();
1249+
auto res = Encoder{mod_in, mod_out}.Encode();
1250+
if (res != Success) {
1251+
return res.Failure();
1252+
}
12281253

12291254
return std::make_unique<pb::Module>(mod_out);
12301255
}
12311256

12321257
Result<Vector<std::byte, 0>> EncodeToBinary(const Module& mod_in) {
12331258
auto mod_out = EncodeToProto(mod_in);
1259+
if (mod_out != Success) {
1260+
return mod_out.Failure();
1261+
}
12341262

12351263
Vector<std::byte, 0> buffer;
1236-
size_t len = mod_out->ByteSizeLong();
1264+
size_t len = mod_out.Get()->ByteSizeLong();
12371265
buffer.Resize(len);
12381266
if (len > 0) {
1239-
if (!mod_out->SerializeToArray(&buffer[0], static_cast<int>(len))) {
1267+
if (!mod_out.Get()->SerializeToArray(&buffer[0], static_cast<int>(len))) {
12401268
return Failure{"failed to serialize protobuf"};
12411269
}
12421270
}

src/tint/lang/core/ir/binary/encode.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class Module;
4646
namespace tint::core::ir::binary {
4747

4848
// Encode the module into a proto representation.
49-
std::unique_ptr<pb::Module> EncodeToProto(const Module& module);
49+
Result<std::unique_ptr<pb::Module>> EncodeToProto(const Module& module);
5050

5151
// Encode the module into a binary representation.
5252
Result<Vector<std::byte, 0>> EncodeToBinary(const Module& module);

src/tint/lang/core/ir/binary/roundtrip_fuzz.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,19 @@
2929
#include "src/tint/lang/core/ir/binary/decode.h"
3030
#include "src/tint/lang/core/ir/binary/encode.h"
3131
#include "src/tint/lang/core/ir/disassembler.h"
32+
#include "src/tint/lang/core/ir/validator.h"
3233

3334
namespace tint::core::ir::binary {
3435
namespace {
3536

3637
void IRBinaryRoundtripFuzzer(core::ir::Module& module) {
3738
auto encoded = EncodeToBinary(module);
3839
if (encoded != Success) {
39-
TINT_ICE() << "Encode() failed\n" << encoded.Failure();
40+
// Failing to encode, not ICE'ing, indicates that an internal limit to the IR binary
41+
// encoding/decoding logic was hit. Due to differences between the AST and IR
42+
// implementations, there exist corner cases where these internal limits are hit for IR,
43+
// but not AST.
44+
return;
4045
}
4146

4247
auto decoded = Decode(encoded->Slice());

0 commit comments

Comments
 (0)