Skip to content

Commit f7c52f2

Browse files
Jiawei-ShaoDawn LUCI CQ
authored andcommitted
[ir] Validate attribute subgroup_size in IR validator
This patch adds the attribute `subgroup_size` and all the related validation in Tint IR validator. - `subgroup_size` must be a constant or override expression - `subgroup_size` must be `i32` or `u32`. - `subgroup_size` must be greater than 0. - `subgroup_size` must be a power of 2. Bug: 463721943 Change-Id: I9de573e6253de3f66d0ee4809674d5f49881671a Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/280415 Reviewed-by: Corentin Wallez <cwallez@chromium.org> Reviewed-by: James Price <jrprice@google.com> Commit-Queue: Shao, Jiawei <jiawei.shao@intel.com>
1 parent fce030b commit f7c52f2

File tree

15 files changed

+1282
-1007
lines changed

15 files changed

+1282
-1007
lines changed

src/tint/cmd/bench/enums_core_bench.cc

Lines changed: 1004 additions & 997 deletions
Large diffs are not rendered by default.

src/tint/lang/core/core.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,9 @@ enum attribute {
249249

250250
// chromium_internal_input_attachments
251251
input_attachment_index
252+
253+
// chromium_experimental_subgroup_size_control
254+
subgroup_size
252255
}
253256

254257
// These are parameter usages which show up in other def files but not in core.def.

src/tint/lang/core/enums.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,6 +1163,9 @@ Attribute ParseAttribute(std::string_view str) {
11631163
if (str == "size") {
11641164
return Attribute::kSize;
11651165
}
1166+
if (str == "subgroup_size") {
1167+
return Attribute::kSubgroupSize;
1168+
}
11661169
if (str == "vertex") {
11671170
return Attribute::kVertex;
11681171
}
@@ -1207,6 +1210,8 @@ std::string_view ToString(Attribute value) {
12071210
return "must_use";
12081211
case Attribute::kSize:
12091212
return "size";
1213+
case Attribute::kSubgroupSize:
1214+
return "subgroup_size";
12101215
case Attribute::kVertex:
12111216
return "vertex";
12121217
case Attribute::kWorkgroupSize:

src/tint/lang/core/enums.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,7 @@ enum class Attribute : uint8_t {
687687
kLocation,
688688
kMustUse,
689689
kSize,
690+
kSubgroupSize,
690691
kVertex,
691692
kWorkgroupSize,
692693
};
@@ -726,6 +727,7 @@ constexpr std::string_view kAttributeStrings[] = {
726727
"location",
727728
"must_use",
728729
"size",
730+
"subgroup_size",
729731
"vertex",
730732
"workgroup_size",
731733
};

src/tint/lang/core/enums_test.cc

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,6 +1083,7 @@ static constexpr AttributeCase kValidAttributeCases[] = {
10831083
{"location", Attribute::kLocation},
10841084
{"must_use", Attribute::kMustUse},
10851085
{"size", Attribute::kSize},
1086+
{"subgroup_size", Attribute::kSubgroupSize},
10861087
{"vertex", Attribute::kVertex},
10871088
{"workgroup_size", Attribute::kWorkgroupSize},
10881089
};
@@ -1136,12 +1137,15 @@ static constexpr AttributeCase kInvalidAttributeCases[] = {
11361137
{"dzIp", Attribute::kUndefined},
11371138
{"ize", Attribute::kUndefined},
11381139
{"LN", Attribute::kUndefined},
1139-
{"r", Attribute::kUndefined},
1140-
{"vxxGteqqR", Attribute::kUndefined},
1141-
{"GGerteS", Attribute::kUndefined},
1142-
{"oqkccr8up_size", Attribute::kUndefined},
1143-
{"workgroup_sze", Attribute::kUndefined},
1144-
{"woppkgroup_sie", Attribute::kUndefined},
1140+
{"bgoupsiz", Attribute::kUndefined},
1141+
{"sxxbgrup_RRGizqq", Attribute::kUndefined},
1142+
{"GGubgrSup_size", Attribute::kUndefined},
1143+
{"88ccte", Attribute::kUndefined},
1144+
{"vrtex", Attribute::kUndefined},
1145+
{"perttx", Attribute::kUndefined},
1146+
{"workgr00up_siqFF55", Attribute::kUndefined},
1147+
{"workgroupsize", Attribute::kUndefined},
1148+
{"wokgroup_sie", Attribute::kUndefined},
11451149
};
11461150

11471151
using AttributeParseTest = testing::TestWithParam<AttributeCase>;

src/tint/lang/core/ir/binary/decode.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,11 @@ struct Decoder {
283283
Value(wg_size_in.z()));
284284
}
285285

286+
if (fn_in.has_subgroup_size()) {
287+
uint32_t subgroup_size_in = fn_in.subgroup_size();
288+
fn_out->SetSubgroupSize(Value(subgroup_size_in));
289+
}
290+
286291
Vector<FunctionParam*, 8> params_out;
287292
for (auto param_in : fn_in.parameters()) {
288293
auto* param_out = ValueAs<FunctionParam>(param_in);

src/tint/lang/core/ir/binary/encode.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ struct Encoder {
150150
wg_size_out.set_y(Value((*wg_size_in)[1]));
151151
wg_size_out.set_z(Value((*wg_size_in)[2]));
152152
}
153+
if (auto subgroup_size_in = fn_in->SubgroupSize()) {
154+
fn_out->set_subgroup_size(Value(*subgroup_size_in));
155+
}
153156
for (auto* param_in : fn_in->Params()) {
154157
fn_out->add_parameters(Value(param_in));
155158
}

src/tint/lang/core/ir/disassembler.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,12 @@ void Disassembler::EmitFunction(const Function* func) {
343343
EmitValue(arr[2]);
344344
out_ << ")";
345345
}
346+
if (func->SubgroupSize()) {
347+
auto subgroup_size = func->SubgroupSize().value();
348+
out_ << " " << StyleAttribute("@subgroup_size") << "(";
349+
EmitValue(subgroup_size);
350+
out_ << ")";
351+
}
346352

347353
out_ << " " << StyleKeyword("func") << "(";
348354

src/tint/lang/core/ir/function.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,26 @@ class Function : public Castable<Function, Value> {
134134
}};
135135
}
136136

137+
/// Sets the subgroup size
138+
/// @param size the new subgroup size
139+
void SetSubgroupSize(Value* size) { subgroup_size_ = size; }
140+
141+
/// @returns the subgroup size information
142+
std::optional<Value*> SubgroupSize() const { return subgroup_size_; }
143+
144+
/// @returns the subgroup size information as `uint32_t` values. Note, this requires the value
145+
/// to be constant.
146+
std::optional<uint32_t> SubgroupSizeAsConst() const {
147+
if (!subgroup_size_.has_value()) {
148+
return std::nullopt;
149+
}
150+
151+
auto* value = subgroup_size_.value()->As<core::ir::Constant>();
152+
TINT_ASSERT(value);
153+
154+
return value->Value()->ValueAs<uint32_t>();
155+
}
156+
137157
/// @param type the return type for the function
138158
void SetReturnType(const core::type::Type* type) { return_.type = type; }
139159

@@ -224,6 +244,7 @@ class Function : public Castable<Function, Value> {
224244
private:
225245
PipelineStage pipeline_stage_ = PipelineStage::kUndefined;
226246
std::optional<std::array<Value*, 3>> workgroup_size_;
247+
std::optional<Value*> subgroup_size_;
227248

228249
struct {
229250
const core::type::Type* type = nullptr;

src/tint/lang/core/ir/validator.cc

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,6 +1205,10 @@ class Validator {
12051205
/// @param func the function to validate
12061206
void CheckWorkgroupSize(const Function* func);
12071207

1208+
/// Validates the subgroup_size attribute for a given function
1209+
/// @param func the function to validate
1210+
void CheckSubgroupSize(const Function* func);
1211+
12081212
/// Validates the specific function as a vertex entry point
12091213
/// @param ep the function to validate
12101214
void CheckPositionPresentForVertexOutput(const Function* ep);
@@ -2700,6 +2704,8 @@ void Validator::CheckFunction(const Function* func) {
27002704

27012705
CheckWorkgroupSize(func);
27022706

2707+
CheckSubgroupSize(func);
2708+
27032709
if (func->Stage() == Function::PipelineStage::kCompute) {
27042710
if (DAWN_UNLIKELY(func->ReturnType() && !func->ReturnType()->Is<core::type::Void>())) {
27052711
AddError(func) << "compute entry point must not have a return type, found "
@@ -3099,6 +3105,69 @@ void Validator::CheckWorkgroupSize(const Function* func) {
30993105
}
31003106
}
31013107

3108+
void Validator::CheckSubgroupSize(const Function* func) {
3109+
// @subgroup_size is optional
3110+
if (!func->SubgroupSize().has_value()) {
3111+
return;
3112+
}
3113+
3114+
if (!func->IsCompute()) {
3115+
AddError(func) << "@subgroup_size only valid on compute entry point";
3116+
return;
3117+
}
3118+
3119+
auto subgroup_size = func->SubgroupSize().value();
3120+
if (!subgroup_size->Type()) {
3121+
AddError(func) << "a @subgroup_size param is missing a type";
3122+
return;
3123+
}
3124+
3125+
auto* ty = subgroup_size->Type();
3126+
if (!ty->IsAnyOf<core::type::I32, core::type::U32>()) {
3127+
AddError(func) << "@subgroup_size param must be an 'i32' or 'u32', received " << NameOf(ty);
3128+
return;
3129+
}
3130+
3131+
if (auto* c = subgroup_size->As<ir::Constant>()) {
3132+
int64_t value = c->Value()->ValueAs<int64_t>();
3133+
if (value <= 0) {
3134+
AddError(func) << "@subgroup_size param must be greater than 0";
3135+
return;
3136+
}
3137+
3138+
if (!IsPowerOfTwo<int64_t>(value)) {
3139+
AddError(func) << "@subgroup_size param must be a power of 2";
3140+
return;
3141+
}
3142+
3143+
return;
3144+
}
3145+
3146+
if (!capabilities_.Contains(Capability::kAllowOverrides)) {
3147+
AddError(func) << "@subgroup_size param is not a constant value, and IR capability "
3148+
"'kAllowOverrides' is not set";
3149+
return;
3150+
}
3151+
3152+
if (auto* r = subgroup_size->As<ir::InstructionResult>()) {
3153+
if (!r->Instruction()) {
3154+
AddError(func) << "instruction for @subgroup_size param is not defined";
3155+
return;
3156+
}
3157+
3158+
if (r->Instruction()->Block() != mod_.root_block) {
3159+
AddError(func) << "@subgroup_size param defined by non-module scope value";
3160+
return;
3161+
}
3162+
3163+
if (r->Instruction()->Is<core::ir::Override>()) {
3164+
return;
3165+
}
3166+
}
3167+
3168+
AddError(func) << "@subgroup_size must be an InstructionResult or a Constant";
3169+
}
3170+
31023171
void Validator::CheckPositionPresentForVertexOutput(const Function* ep) {
31033172
if (IsPositionPresent(ep->ReturnAttributes(), ep->ReturnType())) {
31043173
return;

0 commit comments

Comments
 (0)