Skip to content

Commit d6139bf

Browse files
committed
fix fn pointers and global data
1 parent d5bbedf commit d6139bf

File tree

8 files changed: +136 additions, -16 deletions

src/frontends/llvm/l2s_instr.c

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -452,13 +452,18 @@ EmittedInstr convert_instruction(Parser* p, Node* fn_or_bb, BodyBuilder* b, LLVM
452452
LLVMValueRef callee = LLVMGetCalledValue(instr);
453453
callee = remove_ptr_bitcasts(p, callee);
454454
assert(num_args + 1 == num_ops);
455-
String intrinsic = is_llvm_intrinsic(callee);
456-
if (!intrinsic)
457-
intrinsic = is_shady_intrinsic(callee);
455+
String intrinsic = NULL;
456+
if (LLVMIsAFunction(callee) || LLVMIsAConstant(callee)) {
457+
intrinsic = is_llvm_intrinsic(callee);
458+
if (!intrinsic)
459+
intrinsic = is_shady_intrinsic(callee);
460+
}
458461
if (intrinsic) {
459462
assert(LLVMIsAFunction(callee));
460463
if (strcmp(intrinsic, "llvm.dbg.declare") == 0) {
461464
const Node* target = convert_value(p, LLVMGetOperand(instr, 0));
465+
if (target->tag != Variable_TAG)
466+
return (EmittedInstr) { 0 };
462467
assert(target->tag == Variable_TAG);
463468
const Node* meta = convert_value(p, LLVMGetOperand(instr, 1));
464469
assert(meta->tag == RefDecl_TAG);

src/frontends/llvm/l2s_type.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ const Type* convert_type(Parser* p, LLVMTypeRef t) {
3636
LLVMGetParamTypes(t, param_types);
3737
LARRAY(const Type*, cparam_types, num_params);
3838
for (size_t i = 0; i < num_params; i++)
39-
cparam_types[i] = convert_type(p, param_types[i]);
39+
cparam_types[i] = qualified_type_helper(convert_type(p, param_types[i]), false);
4040
const Type* ret_type = convert_type(p, LLVMGetReturnType(t));
4141
if (LLVMGetTypeKind(LLVMGetReturnType(t)) == LLVMVoidTypeKind)
4242
ret_type = empty_multiple_return_type(a);

src/runtime/vulkan/vk_runtime_program.c

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "arena.h"
1010
#include "util.h"
11+
#include "type.h"
1112

1213
#include "../../shady/transform/memory_layout.h"
1314

@@ -42,6 +43,46 @@ VkDescriptorType as_to_descriptor_type(AddressSpace as) {
4243
}
4344
}
4445

46+
static void write_value(unsigned char* tgt, const Node* value) {
47+
IrArena* a = value->arena;
48+
switch (value->tag) {
49+
case IntLiteral_TAG: {
50+
switch (value->payload.int_literal.width) {
51+
case IntTy8: *((uint8_t*) tgt) = (uint8_t) (value->payload.int_literal.value & 0xFF); break;
52+
case IntTy16: *((uint16_t*) tgt) = (uint16_t) (value->payload.int_literal.value & 0xFFFF); break;
53+
case IntTy32: *((uint32_t*) tgt) = (uint32_t) (value->payload.int_literal.value & 0xFFFFFFFF); break;
54+
case IntTy64: *((uint64_t*) tgt) = (uint64_t) (value->payload.int_literal.value); break;
55+
}
56+
break;
57+
}
58+
case Composite_TAG: {
59+
Nodes values = value->payload.composite.contents;
60+
const Type* struct_t = value->payload.composite.type;
61+
struct_t = get_maybe_nominal_type_body(struct_t);
62+
63+
if (struct_t->tag == RecordType_TAG) {
64+
LARRAY(FieldLayout, fields, values.count);
65+
get_record_layout(a, struct_t, fields);
66+
for (size_t i = 0; i < values.count; i++) {
67+
// TypeMemLayout layout = get_mem_layout(value->arena, get_unqualified_type(element->type));
68+
write_value(tgt + fields->offset_in_bytes, values.nodes[i]);
69+
}
70+
} else if (struct_t->tag == ArrType_TAG) {
71+
for (size_t i = 0; i < values.count; i++) {
72+
TypeMemLayout layout = get_mem_layout(value->arena, get_unqualified_type(values.nodes[i]->type));
73+
write_value(tgt, values.nodes[i]);
74+
tgt += layout.size_in_bytes;
75+
}
76+
} else {
77+
assert(false);
78+
}
79+
break;
80+
}
81+
default:
82+
assert(false);
83+
}
84+
}
85+
4586
static bool extract_resources_layout(VkrSpecProgram* program, VkDescriptorSetLayout layouts[]) {
4687
VkDescriptorSetLayoutCreateInfo layout_create_infos[MAX_DESCRIPTOR_SETS] = { 0 };
4788
Growy* bindings_lists[MAX_DESCRIPTOR_SETS] = { 0 };
@@ -78,6 +119,8 @@ static bool extract_resources_layout(VkrSpecProgram* program, VkDescriptorSetLay
78119

79120
for (size_t j = 0; j < struct_t->payload.record_type.members.count; j++) {
80121
const Type* member_t = struct_t->payload.record_type.members.nodes[j];
122+
assert(member_t->tag == PtrType_TAG);
123+
member_t = get_pointee_type(program->arena, member_t);
81124
TypeMemLayout layout = get_mem_layout(program->specialized_module->arena, member_t);
82125

83126
ProgramResourceInfo* constant_res_info = arena_alloc(program->arena, sizeof(ProgramResourceInfo));
@@ -93,6 +136,15 @@ static bool extract_resources_layout(VkrSpecProgram* program, VkDescriptorSetLay
93136
res_info->size += sizeof(void*);
94137

95138
// TODO initial value
139+
Nodes annotations = get_declaration_annotations(decl);
140+
for (size_t k = 0; k < annotations.count; k++) {
141+
const Node* a = annotations.nodes[k];
142+
if ((strcmp(get_annotation_name(a), "InitialValue") == 0) && resolve_to_int_literal(first(get_annotation_values(a)))->value == j) {
143+
constant_res_info->default_data = calloc(1, layout.size_in_bytes);
144+
write_value(constant_res_info->default_data, get_annotation_values(a).nodes[1]);
145+
//printf("wowie");
146+
}
147+
}
96148
}
97149

98150
if (vkr_can_import_host_memory(program->device))
@@ -394,11 +446,13 @@ static bool prepare_resources(VkrSpecProgram* program) {
394446
resource->buffer = allocate_buffer_device(program->device, resource->size);
395447
}
396448

397-
// TODO: initial data!
398-
// if (!resource->host_owned)
399-
char* zeroes = calloc(1, resource->size);
400-
copy_to_buffer(resource->buffer, 0, zeroes, resource->size);
401-
free(zeroes);
449+
if (resource->default_data) {
450+
copy_to_buffer(resource->buffer, 0, resource->default_data, resource->size);
451+
} else {
452+
char* zeroes = calloc(1, resource->size);
453+
copy_to_buffer(resource->buffer, 0, zeroes, resource->size);
454+
free(zeroes);
455+
}
402456

403457
if (resource->parent) {
404458
char* dst = resource->parent->host_ptr;

src/shady/analysis/callgraph.c

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ static const Node* ignore_immediate_fn_addr(const Node* node) {
3838
return node;
3939
}
4040

41+
static CGNode* analyze_fn(CallGraph* graph, const Node* fn);
42+
4143
static void visit_callsite(CGVisitor* visitor, const Node* callee, const Node* instr) {
44+
assert(visitor->root);
4245
assert(callee->tag == Function_TAG);
4346
CGNode* target = analyze_fn(visitor->graph, callee);
4447
// Immediate recursion
@@ -55,10 +58,11 @@ static void visit_callsite(CGVisitor* visitor, const Node* callee, const Node* i
5558
}
5659

5760
static void search_for_callsites(CGVisitor* visitor, const Node* node) {
58-
assert(is_abstraction(visitor->abs));
61+
assert((visitor->abs && is_abstraction(visitor->abs)) || !visitor->root);
5962
switch (node->tag) {
6063
case Function_TAG: {
6164
assert(false);
65+
// analyze_fn(visitor->graph, node)->is_address_captured = true;
6266
break;
6367
}
6468
case BasicBlock_TAG:
@@ -74,12 +78,15 @@ static void search_for_callsites(CGVisitor* visitor, const Node* node) {
7478
break;
7579
}
7680
case Call_TAG: {
81+
assert(visitor->root && "calls can only occur in functions");
7782
const Node* callee = node->payload.call.callee;
7883
callee = ignore_immediate_fn_addr(callee);
7984
if (callee->tag == Function_TAG)
8085
visit_callsite(visitor, callee, node);
81-
else
86+
else {
87+
visitor->root->calls_indirect = true;
8288
visit_op(&visitor->visitor, NcValue, "callee", callee);
89+
}
8390
visit_ops(&visitor->visitor, NcValue, "args", node->payload.call.args);
8491
break;
8592
}
@@ -216,8 +223,29 @@ CallGraph* new_callgraph(Module* mod) {
216223

217224
Nodes decls = get_module_declarations(mod);
218225
for (size_t i = 0; i < decls.count; i++) {
219-
if (decls.nodes[i]->tag == Function_TAG) {
220-
analyze_fn(graph, decls.nodes[i]);
226+
const Node* decl = decls.nodes[i];
227+
if (decl->tag == Function_TAG) {
228+
analyze_fn(graph, decl);
229+
} else if (decl->tag == GlobalVariable_TAG && decl->payload.global_variable.init) {
230+
CGVisitor v = {
231+
.visitor = {
232+
.visit_node_fn = (VisitNodeFn) search_for_callsites
233+
},
234+
.graph = graph,
235+
.root = NULL,
236+
.abs = NULL,
237+
};
238+
search_for_callsites(&v, decl->payload.global_variable.init);
239+
} else if (decl->tag == Constant_TAG && decl->payload.constant.instruction) {
240+
CGVisitor v = {
241+
.visitor = {
242+
.visit_node_fn = (VisitNodeFn) search_for_callsites
243+
},
244+
.graph = graph,
245+
.root = NULL,
246+
.abs = NULL,
247+
};
248+
search_for_callsites(&v, decl->payload.constant.instruction);
221249
}
222250
}
223251

src/shady/analysis/callgraph.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ struct CGNode_ {
2424
bool is_recursive;
2525
/// set to true if the address of this is captured by a FnAddr node that is not immediately consumed by a call
2626
bool is_address_captured;
27+
bool calls_indirect;
2728
};
2829

2930
typedef struct Callgraph_ {

src/shady/passes/lower_callf.c

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ static const Node* lower_callf_process(Context* ctx, const Node* old) {
2424
if (found) return found;
2525
IrArena* a = ctx->rewriter.dst_arena;
2626
Module* m = ctx->rewriter.dst_module;
27+
Rewriter* r = &ctx->rewriter;
2728

2829
if (old->tag == Function_TAG) {
2930
Context ctx2 = *ctx;
@@ -67,6 +68,19 @@ static const Node* lower_callf_process(Context* ctx, const Node* old) {
6768
return recreate_node_identity(&ctx->rewriter, old);
6869

6970
switch (old->tag) {
71+
case FnType_TAG: {
72+
Nodes param_types = rewrite_nodes(r, old->payload.fn_type.param_types);
73+
Nodes returned_types = rewrite_nodes(&ctx->rewriter, old->payload.fn_type.return_types);
74+
const Type* jp_type = qualified_type(a, (QualifiedType) {
75+
.type = join_point_type(a, (JoinPointType) { .yield_types = strip_qualifiers(a, returned_types) }),
76+
.is_uniform = false
77+
});
78+
param_types = append_nodes(a, param_types, jp_type);
79+
return fn_type(a, (FnType) {
80+
.param_types = param_types,
81+
.return_types = empty(a),
82+
});
83+
}
7084
case Return_TAG: {
7185
Nodes nargs = rewrite_nodes(&ctx->rewriter, old->payload.fn_ret.args);
7286

src/shady/passes/mark_leaf_functions.c

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,28 @@ static bool is_leaf_fn(Context* ctx, CGNode* fn_node) {
4747
info = find_value_dict(const Node*, FnInfo, ctx->fns, fn_node->fn);
4848
assert(info);
4949

50-
if (fn_node->is_address_captured || fn_node->is_recursive) {
50+
if (fn_node->is_address_captured || fn_node->is_recursive || fn_node->calls_indirect) {
5151
info->is_leaf = false;
5252
info->done = true;
53-
debugv_print("Function %s can't be a leaf function because %s.\n", get_abstraction_name(fn_node->fn), fn_node->is_address_captured ? "its address is captured" : "it is recursive" );
53+
debugv_print("Function %s can't be a leaf function because", get_abstraction_name(fn_node->fn));
54+
bool and = false;
55+
if (fn_node->is_address_captured) {
56+
debugv_print("its address is captured");
57+
and = true;
58+
}
59+
if (fn_node->is_recursive) {
60+
if (and)
61+
debugv_print(" and ");
62+
debugv_print("it is recursive");
63+
and = true;
64+
}
65+
if (fn_node->calls_indirect) {
66+
if (and)
67+
debugv_print(" and ");
68+
debugv_print("it makes indirect calls");
69+
and = true;
70+
}
71+
debugv_print(".\n");
5472
return false;
5573
}
5674

src/shady/passes/spirv_lift_globals_ssbo.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ Module* spirv_lift_globals_ssbo(SHADY_UNUSED const CompilerConfig* config, Modul
104104
if (odecl->tag != GlobalVariable_TAG || odecl->payload.global_variable.address_space != AsGlobal)
105105
continue;
106106
if (odecl->payload.global_variable.init)
107-
annotations = append_nodes(a, annotations, annotation_values(a, (AnnotationValues) {
107+
ctx.lifted_globals_decl->payload.global_variable.annotations = append_nodes(a, ctx.lifted_globals_decl->payload.global_variable.annotations, annotation_values(a, (AnnotationValues) {
108108
.name = "InitialValue",
109109
.values = mk_nodes(a, int32_literal(a, lifted_globals_count), rewrite_node(&ctx.rewriter, odecl->payload.global_variable.init))
110110
}));

0 commit comments

Comments
 (0)