@@ -101,17 +101,20 @@ static const Node* process_let(Context* ctx, const Node* node) {
101
101
const PrimOp * oprim_op = & old_instruction -> payload .prim_op ;
102
102
switch (oprim_op -> op ) {
103
103
case get_stack_pointer_op : {
104
+ assert (ctx -> stack );
104
105
BodyBuilder * bb = begin_body (a );
105
106
const Node * sp = gen_load (bb , ctx -> stack_pointer );
106
107
return finish_body (bb , let (a , quote_helper (a , singleton (sp )), tail ));
107
108
}
108
109
case set_stack_pointer_op : {
110
+ assert (ctx -> stack );
109
111
BodyBuilder * bb = begin_body (a );
110
112
const Node * val = rewrite_node (& ctx -> rewriter , oprim_op -> operands .nodes [0 ]);
111
113
gen_store (bb , ctx -> stack_pointer , val );
112
114
return finish_body (bb , let (a , quote_helper (a , empty (a )), tail ));
113
115
}
114
116
case get_stack_base_op : {
117
+ assert (ctx -> stack );
115
118
BodyBuilder * bb = begin_body (a );
116
119
const Node * stack_pointer = ctx -> stack_pointer ;
117
120
const Node * stack_size = gen_load (bb , stack_pointer );
@@ -126,6 +129,7 @@ static const Node* process_let(Context* ctx, const Node* node) {
126
129
}
127
130
case push_stack_op :
128
131
case pop_stack_op : {
132
+ assert (ctx -> stack );
129
133
BodyBuilder * bb = begin_body (a );
130
134
const Type * element_type = rewrite_node (& ctx -> rewriter , first (oprim_op -> type_arguments ));
131
135
@@ -161,8 +165,10 @@ static const Node* process_node(Context* ctx, const Node* old) {
161
165
// Make sure to zero-init the stack pointers
162
166
// TODO isn't this redundant with thoose things having an initial value already ?
163
167
// is this an old forgotten workaround ?
164
- const Node * stack_pointer = ctx -> stack_pointer ;
165
- gen_store (bb , stack_pointer , uint32_literal (a , 0 ));
168
+ if (ctx -> stack ) {
169
+ const Node * stack_pointer = ctx -> stack_pointer ;
170
+ gen_store (bb , stack_pointer , uint32_literal (a , 0 ));
171
+ }
166
172
new -> payload .fun .body = finish_body (bb , rewrite_node (& ctx -> rewriter , old -> payload .fun .body ));
167
173
return new ;
168
174
}
@@ -181,34 +187,36 @@ Module* lower_stack(SHADY_UNUSED const CompilerConfig* config, Module* src) {
181
187
IrArena * a = new_ir_arena (aconfig );
182
188
Module * dst = new_module (a , get_module_name (src ));
183
189
184
- const Type * stack_base_element = uint8_type (a );
185
- const Type * stack_arr_type = arr_type (a , (ArrType ) {
186
- .element_type = stack_base_element ,
187
- .size = uint32_literal (a , config -> per_thread_stack_size ),
188
- });
189
- const Type * stack_counter_t = uint32_type (a );
190
-
191
- Nodes annotations = mk_nodes (a , annotation (a , (Annotation ) { .name = "Generated" }));
192
-
193
- // Arrays for the stacks
194
- Node * stack_decl = global_var (dst , annotations , stack_arr_type , "stack" , AsPrivate );
195
-
196
- // Pointers into those arrays
197
- Node * stack_ptr_decl = global_var (dst , append_nodes (a , annotations , annotation (a , (Annotation ) { .name = "Logical" })), stack_counter_t , "stack_ptr" , AsPrivate );
198
- stack_ptr_decl -> payload .global_variable .init = uint32_literal (a , 0 );
199
-
200
190
Context ctx = {
201
191
.rewriter = create_rewriter (src , dst , (RewriteNodeFn ) process_node ),
202
192
203
193
.config = config ,
204
194
205
195
.push = new_dict (const Node * , Node * , (HashFn ) hash_node , (CmpFn ) compare_node ),
206
196
.pop = new_dict (const Node * , Node * , (HashFn ) hash_node , (CmpFn ) compare_node ),
207
-
208
- .stack = ref_decl_helper (a , stack_decl ),
209
- .stack_pointer = ref_decl_helper (a , stack_ptr_decl ),
210
197
};
211
198
199
+ if (config -> per_thread_stack_size > 0 ) {
200
+ const Type * stack_base_element = uint8_type (a );
201
+ const Type * stack_arr_type = arr_type (a , (ArrType ) {
202
+ .element_type = stack_base_element ,
203
+ .size = uint32_literal (a , config -> per_thread_stack_size ),
204
+ });
205
+ const Type * stack_counter_t = uint32_type (a );
206
+
207
+ Nodes annotations = mk_nodes (a , annotation (a , (Annotation ) { .name = "Generated" }));
208
+
209
+ // Arrays for the stacks
210
+ Node * stack_decl = global_var (dst , annotations , stack_arr_type , "stack" , AsPrivate );
211
+
212
+ // Pointers into those arrays
213
+ Node * stack_ptr_decl = global_var (dst , append_nodes (a , annotations , annotation (a , (Annotation ) { .name = "Logical" })), stack_counter_t , "stack_ptr" , AsPrivate );
214
+ stack_ptr_decl -> payload .global_variable .init = uint32_literal (a , 0 );
215
+
216
+ ctx .stack = ref_decl_helper (a , stack_decl );
217
+ ctx .stack_pointer = ref_decl_helper (a , stack_ptr_decl );
218
+ }
219
+
212
220
rewrite_module (& ctx .rewriter );
213
221
destroy_rewriter (& ctx .rewriter );
214
222
0 commit comments