Skip to content

Commit ab4bb8d

Browse files
committed
fix handling of SUBGROUP_SIZE
1 parent a98c01d commit ab4bb8d

File tree

4 files changed

+13
-9
lines changed

4 files changed

+13
-9
lines changed

src/shady/compile.c

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,7 @@ CompilationResult run_compiler_passes(CompilerConfig* config, Module** pmod) {
113113
RUN_PASS(lift_indirect_targets)
114114
RUN_PASS(opt_mem2reg) // run because we can now weaken non-leaking allocas
115115

116-
if (config->specialization.execution_model != EmNone)
117-
RUN_PASS(specialize_execution_model)
116+
RUN_PASS(specialize_execution_model)
118117

119118
RUN_PASS(opt_stack)
120119

src/shady/passes/specialize_entry_point.c

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,7 @@ static const Node* process(Context* ctx, const Node* node) {
5050
}
5151
case Constant_TAG: {
5252
Node* ncnst = (Node*) recreate_node_identity(&ctx->rewriter, node);
53-
if (strcmp(get_declaration_name(ncnst), "SUBGROUP_SIZE") == 0) {
54-
ncnst->payload.constant.instruction = quote_helper(a, singleton(uint32_literal(a, ctx->config->specialization.subgroup_size)));
55-
} else if (strcmp(get_declaration_name(ncnst), "SUBGROUPS_PER_WG") == 0) {
53+
if (strcmp(get_declaration_name(ncnst), "SUBGROUPS_PER_WG") == 0) {
5654
// SUBGROUPS_PER_WG = (NUMBER OF INVOCATIONS IN SUBGROUP / SUBGROUP SIZE)
5755
// Note: this computations assumes only full subgroups are launched, if subgroups can launch partially filled then this relationship does not hold.
5856
uint32_t wg_size[3];
@@ -87,9 +85,6 @@ static const Node* find_entry_point(Module* m, const CompilerConfig* config) {
8785
}
8886

8987
static void specialize_arena_config(const CompilerConfig* config, Module* src, ArenaConfig* target) {
90-
size_t subgroup_size = config->specialization.subgroup_size;
91-
assert(subgroup_size);
92-
9388
const Node* old_entry_point_decl = find_entry_point(src, config);
9489
if (old_entry_point_decl->tag != Function_TAG)
9590
error("%s is not a function", config->specialization.entry_point);

src/shady/passes/specialize_execution_model.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ static const Node* process(Context* ctx, const Node* node) {
2222

2323
IrArena* a = ctx->rewriter.dst_arena;
2424
switch (node->tag) {
25+
case Constant_TAG: {
26+
Node* ncnst = (Node*) recreate_node_identity(&ctx->rewriter, node);
27+
if (strcmp(get_declaration_name(ncnst), "SUBGROUP_SIZE") == 0) {
28+
ncnst->payload.constant.instruction = quote_helper(a, singleton(uint32_literal(a, ctx->config->specialization.subgroup_size)));
29+
}
30+
return ncnst;
31+
}
2532
default: break;
2633
}
2734
return recreate_node_identity(&ctx->rewriter, node);
@@ -44,6 +51,9 @@ Module* specialize_execution_model(const CompilerConfig* config, Module* src) {
4451
IrArena* a = new_ir_arena(aconfig);
4552
Module* dst = new_module(a, get_module_name(src));
4653

54+
size_t subgroup_size = config->specialization.subgroup_size;
55+
assert(subgroup_size);
56+
4757
Context ctx = {
4858
.rewriter = create_rewriter(src, dst, (RewriteNodeFn) process),
4959
.config = config,

src/shady/transform/internal_constants.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include "shady/ir.h"
55

66
#define INTERNAL_CONSTANTS(X) \
7-
X(SUBGROUP_SIZE, uint32_type(arena), uint32_literal(arena, config->specialization.subgroup_size)) \
7+
X(SUBGROUP_SIZE, uint32_type(arena), uint32_literal(arena, 64)) \
88
X(SUBGROUPS_PER_WG, uint32_type(arena), uint32_literal(arena, 1)) \
99

1010
void generate_dummy_constants(const CompilerConfig* config, Module*);

0 commit comments

Comments
 (0)