Skip to content

Commit 10035b8

Browse files
committed
cli: added --stack-size option
1 parent 5cea342 commit 10035b8

File tree

4 files changed

+38
-24
lines changed

4 files changed

+38
-24
lines changed

src/driver/cli.c

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,14 @@ void cli_parse_compiler_config_args(CompilerConfig* config, int* pargc, char** a
121121
argv[i] = NULL;
122122
i++;
123123
if (i == argc)
124-
error("Missing subgroup size name");
124+
error("Missing subgroup size");
125125
config->specialization.subgroup_size = atoi(argv[i]);
126+
} else if (strcmp(argv[i], "--stack-size") == 0) {
127+
argv[i] = NULL;
128+
i++;
129+
if (i == argc)
130+
error("Missing stack size");
131+
config->per_thread_stack_size = atoi(argv[i]);
126132
} else if (strcmp(argv[i], "--execution-model") == 0) {
127133
argv[i] = NULL;
128134
i++;

src/shady/passes/lower_alloca.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ static const Node* process(Context* ctx, const Node* node) {
8585
case Function_TAG: {
8686
Node* fun = recreate_decl_header_identity(&ctx->rewriter, node);
8787
Context ctx2 = *ctx;
88-
ctx2.disable_lowering = lookup_annotation_with_string_payload(node, "DisablePass", "setup_stack_frames");
88+
ctx2.disable_lowering = lookup_annotation_with_string_payload(node, "DisablePass", "setup_stack_frames") || ctx->config->per_thread_stack_size == 0;
8989
ctx2.prepared_offsets = new_dict(const Node*, StackSlot, (HashFn) hash_node, (CmpFn) compare_node);
9090

9191
BodyBuilder* bb = begin_body(a);

src/shady/passes/lower_stack.c

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -101,17 +101,20 @@ static const Node* process_let(Context* ctx, const Node* node) {
101101
const PrimOp* oprim_op = &old_instruction->payload.prim_op;
102102
switch (oprim_op->op) {
103103
case get_stack_pointer_op: {
104+
assert(ctx->stack);
104105
BodyBuilder* bb = begin_body(a);
105106
const Node* sp = gen_load(bb, ctx->stack_pointer);
106107
return finish_body(bb, let(a, quote_helper(a, singleton(sp)), tail));
107108
}
108109
case set_stack_pointer_op: {
110+
assert(ctx->stack);
109111
BodyBuilder* bb = begin_body(a);
110112
const Node* val = rewrite_node(&ctx->rewriter, oprim_op->operands.nodes[0]);
111113
gen_store(bb, ctx->stack_pointer, val);
112114
return finish_body(bb, let(a, quote_helper(a, empty(a)), tail));
113115
}
114116
case get_stack_base_op: {
117+
assert(ctx->stack);
115118
BodyBuilder* bb = begin_body(a);
116119
const Node* stack_pointer = ctx->stack_pointer;
117120
const Node* stack_size = gen_load(bb, stack_pointer);
@@ -126,6 +129,7 @@ static const Node* process_let(Context* ctx, const Node* node) {
126129
}
127130
case push_stack_op:
128131
case pop_stack_op: {
132+
assert(ctx->stack);
129133
BodyBuilder* bb = begin_body(a);
130134
const Type* element_type = rewrite_node(&ctx->rewriter, first(oprim_op->type_arguments));
131135

@@ -161,8 +165,10 @@ static const Node* process_node(Context* ctx, const Node* old) {
161165
// Make sure to zero-init the stack pointers
162166
// TODO isn't this redundant with those things having an initial value already ?
163167
// is this an old forgotten workaround ?
164-
const Node* stack_pointer = ctx->stack_pointer;
165-
gen_store(bb, stack_pointer, uint32_literal(a, 0));
168+
if (ctx->stack) {
169+
const Node* stack_pointer = ctx->stack_pointer;
170+
gen_store(bb, stack_pointer, uint32_literal(a, 0));
171+
}
166172
new->payload.fun.body = finish_body(bb, rewrite_node(&ctx->rewriter, old->payload.fun.body));
167173
return new;
168174
}
@@ -181,34 +187,36 @@ Module* lower_stack(SHADY_UNUSED const CompilerConfig* config, Module* src) {
181187
IrArena* a = new_ir_arena(aconfig);
182188
Module* dst = new_module(a, get_module_name(src));
183189

184-
const Type* stack_base_element = uint8_type(a);
185-
const Type* stack_arr_type = arr_type(a, (ArrType) {
186-
.element_type = stack_base_element,
187-
.size = uint32_literal(a, config->per_thread_stack_size),
188-
});
189-
const Type* stack_counter_t = uint32_type(a);
190-
191-
Nodes annotations = mk_nodes(a, annotation(a, (Annotation) { .name = "Generated" }));
192-
193-
// Arrays for the stacks
194-
Node* stack_decl = global_var(dst, annotations, stack_arr_type, "stack", AsPrivate);
195-
196-
// Pointers into those arrays
197-
Node* stack_ptr_decl = global_var(dst, append_nodes(a, annotations, annotation(a, (Annotation) { .name = "Logical" })), stack_counter_t, "stack_ptr", AsPrivate);
198-
stack_ptr_decl->payload.global_variable.init = uint32_literal(a, 0);
199-
200190
Context ctx = {
201191
.rewriter = create_rewriter(src, dst, (RewriteNodeFn) process_node),
202192

203193
.config = config,
204194

205195
.push = new_dict(const Node*, Node*, (HashFn) hash_node, (CmpFn) compare_node),
206196
.pop = new_dict(const Node*, Node*, (HashFn) hash_node, (CmpFn) compare_node),
207-
208-
.stack = ref_decl_helper(a, stack_decl),
209-
.stack_pointer = ref_decl_helper(a, stack_ptr_decl),
210197
};
211198

199+
if (config->per_thread_stack_size > 0) {
200+
const Type* stack_base_element = uint8_type(a);
201+
const Type* stack_arr_type = arr_type(a, (ArrType) {
202+
.element_type = stack_base_element,
203+
.size = uint32_literal(a, config->per_thread_stack_size),
204+
});
205+
const Type* stack_counter_t = uint32_type(a);
206+
207+
Nodes annotations = mk_nodes(a, annotation(a, (Annotation) { .name = "Generated" }));
208+
209+
// Arrays for the stacks
210+
Node* stack_decl = global_var(dst, annotations, stack_arr_type, "stack", AsPrivate);
211+
212+
// Pointers into those arrays
213+
Node* stack_ptr_decl = global_var(dst, append_nodes(a, annotations, annotation(a, (Annotation) { .name = "Logical" })), stack_counter_t, "stack_ptr", AsPrivate);
214+
stack_ptr_decl->payload.global_variable.init = uint32_literal(a, 0);
215+
216+
ctx.stack = ref_decl_helper(a, stack_decl);
217+
ctx.stack_pointer = ref_decl_helper(a, stack_ptr_decl);
218+
}
219+
212220
rewrite_module(&ctx.rewriter);
213221
destroy_rewriter(&ctx.rewriter);
214222

src/shady/passes/setup_stack_frames.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ static const Node* process(Context* ctx, const Node* node) {
3131
case Function_TAG: {
3232
Node* fun = recreate_decl_header_identity(r, node);
3333
Context ctx2 = *ctx;
34-
ctx2.disable_lowering = lookup_annotation_with_string_payload(node, "DisablePass", "setup_stack_frames");
34+
ctx2.disable_lowering = lookup_annotation_with_string_payload(node, "DisablePass", "setup_stack_frames") || ctx->config->per_thread_stack_size == 0;
3535

3636
BodyBuilder* bb = begin_body(a);
3737
if (!ctx2.disable_lowering) {

0 commit comments

Comments
 (0)