Skip to content

Commit 6617053

Browse files
committed
remove support for untyped ptr in the IR: instead do it in l2s
1 parent 90127fa commit 6617053

File tree

18 files changed

+179
-188
lines changed

18 files changed

+179
-188
lines changed

include/shady/ir.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ typedef struct {
6969
bool check_op_classes;
7070
bool check_types;
7171
bool allow_fold;
72-
bool untyped_ptrs;
7372
bool validate_builtin_types; // do @Builtins variables need to match their type in builtins.h ?
7473
bool is_simt;
7574

src/frontends/llvm/l2s.c

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
#include "l2s_private.h"
2-
#include "shady/ir_private.h"
2+
3+
#include "ir_private.h"
4+
#include "type.h"
5+
#include "analysis/verify.h"
36

47
#include "log.h"
58
#include "dict.h"
69
#include "list.h"
710
#include "util.h"
11+
#include "portability.h"
812

913
#include "llvm-c/IRReader.h"
10-
#include "portability.h"
1114

1215
#include <assert.h>
1316
#include <string.h>
@@ -80,7 +83,7 @@ static TodoBB prepare_bb(Parser* p, Node* fn, LLVMBasicBlockRef bb) {
8083
while (instr) {
8184
switch (LLVMGetInstructionOpcode(instr)) {
8285
case LLVMPHI: {
83-
const Node* nparam = var(a, convert_type(p, LLVMTypeOf(instr)), "phi");
86+
const Node* nparam = var(a, qualified_type_helper(convert_type(p, LLVMTypeOf(instr)), false), "phi");
8487
insert_dict(LLVMValueRef, const Node*, p->map, instr, nparam);
8588
append_list(LLVMValueRef, phis, instr);
8689
params = append_nodes(a, params, nparam);
@@ -128,7 +131,7 @@ const Node* convert_function(Parser* p, LLVMValueRef fn) {
128131
for (LLVMValueRef oparam = LLVMGetFirstParam(fn); oparam; oparam = LLVMGetNextParam(oparam)) {
129132
LLVMTypeRef ot = LLVMTypeOf(oparam);
130133
const Type* t = convert_type(p, ot);
131-
const Node* param = var(a, t, LLVMGetValueName(oparam));
134+
const Node* param = var(a, qualified_type_helper(t, false), LLVMGetValueName(oparam));
132135
insert_dict(LLVMValueRef, const Node*, p->map, oparam, param);
133136
params = append_nodes(a, params, param);
134137
if (oparam == LLVMGetLastParam(fn))
@@ -217,6 +220,13 @@ const Node* convert_global(Parser* p, LLVMValueRef global) {
217220
decl = global_var(p->dst, empty(a), type, name, as);
218221
if (value && as != AsUniformConstant)
219222
decl->payload.global_variable.init = convert_value(p, value);
223+
224+
if (UNTYPED_POINTERS) {
225+
Node* untyped_wrapper = constant(p->dst, empty(a), ptr_t, format_string_interned(a, "%s_untyped", name));
226+
untyped_wrapper->payload.constant.instruction = quote_helper(a, singleton(ref_decl_helper(a, decl)));
227+
untyped_wrapper->payload.constant.instruction = prim_op_helper(a, reinterpret_op, singleton(ptr_t), singleton(ref_decl_helper(a, decl)));
228+
decl = untyped_wrapper;
229+
}
220230
} else {
221231
const Type* type = convert_type(p, LLVMTypeOf(global));
222232
decl = constant(p->dst, empty(a), type, name);
@@ -241,9 +251,6 @@ bool parse_llvm_into_shady(const CompilerConfig* config, size_t len, const char*
241251
error_die();
242252
}
243253
info_print("LLVM IR parsed successfully\n");
244-
#if UNTYPED_POINTERS
245-
get_module_arena(dst)->config.untyped_ptrs = true; // tolerate untyped ptrs...
246-
#endif
247254

248255
ArenaConfig aconfig = default_arena_config();
249256
aconfig.check_types = false;
@@ -278,6 +285,7 @@ bool parse_llvm_into_shady(const CompilerConfig* config, size_t len, const char*
278285
break;
279286
global = LLVMGetNextGlobal(global);
280287
}
288+
log_module(DEBUG, config, dirty);
281289

282290
aconfig.check_types = true;
283291
aconfig.allow_fold = true;

src/frontends/llvm/l2s_instr.c

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,15 @@ static const Node* convert_jump(Parser* p, Node* fn, Node* fn_or_bb, LLVMBasicBl
7676
return jump_helper(a, dst_bb, nodes(a, params_count, params));
7777
}
7878

79+
static const Type* type_untyped_ptr(const Type* untyped_ptr_t, const Type* element_type) {
80+
IrArena* a = untyped_ptr_t->arena;
81+
assert(element_type);
82+
assert(untyped_ptr_t->tag == PtrType_TAG);
83+
assert(!untyped_ptr_t->payload.ptr_type.is_reference);
84+
const Type* typed_ptr_t = ptr_type(a, (PtrType) { .pointed_type = element_type, .address_space = untyped_ptr_t->payload.ptr_type.address_space });
85+
return typed_ptr_t;
86+
}
87+
7988
/// instr may be an instruction or a constantexpr
8089
EmittedInstr convert_instruction(Parser* p, Node* fn_or_bb, BodyBuilder* b, LLVMValueRef instr) {
8190
Node* fn = fn_or_bb;
@@ -223,7 +232,7 @@ EmittedInstr convert_instruction(Parser* p, Node* fn_or_bb, BodyBuilder* b, LLVM
223232
r = first(bind_instruction_explicit_result_types(b, prim_op_helper(a, alloca_op, singleton(allocated_t), empty(a)), singleton(allocated_ptr_t), NULL));
224233
if (UNTYPED_POINTERS) {
225234
const Type* untyped_ptr_t = ptr_type(a, (PtrType) { .pointed_type = unit_type(a), .address_space = AsPrivate });
226-
r = first(bind_instruction_outputs_count(b, prim_op_helper(a, reinterpret_op, singleton(untyped_ptr_t), singleton(r)), 1, NULL));
235+
r = first(bind_instruction_explicit_result_types(b, prim_op_helper(a, reinterpret_op, singleton(untyped_ptr_t), singleton(r)), singleton(untyped_ptr_t), NULL));
227236
}
228237
r = prim_op_helper(a, convert_op, singleton(t), singleton(r));
229238
break;
@@ -232,19 +241,49 @@ EmittedInstr convert_instruction(Parser* p, Node* fn_or_bb, BodyBuilder* b, LLVM
232241
Nodes ops = convert_operands(p, num_ops, instr);
233242
assert(ops.count == 1);
234243
const Node* ptr = first(ops);
235-
r = prim_op_helper(a, load_op, singleton(t), singleton(ptr));
244+
if (UNTYPED_POINTERS) {
245+
const Type* element_t = t;
246+
const Type* untyped_ptr_t = convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0)));
247+
const Type* typed_ptr = type_untyped_ptr(untyped_ptr_t, element_t);
248+
ptr = first(bind_instruction_explicit_result_types(b, prim_op_helper(a, reinterpret_op, singleton(typed_ptr), singleton(ptr)), singleton(typed_ptr), NULL));
249+
}
250+
r = prim_op_helper(a, load_op, empty(a), singleton(ptr));
236251
break;
237252
}
238253
case LLVMStore: {
239254
num_results = 0;
240255
Nodes ops = convert_operands(p, num_ops, instr);
241256
assert(ops.count == 2);
242-
r = prim_op_helper(a, store_op, UNTYPED_POINTERS ? singleton(convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0)))) : empty(a), mk_nodes(a, ops.nodes[1], ops.nodes[0]));
257+
const Node* ptr = ops.nodes[1];
258+
if (UNTYPED_POINTERS) {
259+
const Type* element_t = convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0)));
260+
const Type* untyped_ptr_t = convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 1)));
261+
const Type* typed_ptr = type_untyped_ptr(untyped_ptr_t, element_t);
262+
ptr = first(bind_instruction_explicit_result_types(b, prim_op_helper(a, reinterpret_op, singleton(typed_ptr), singleton(ptr)), singleton(typed_ptr), NULL));
263+
}
264+
r = prim_op_helper(a, store_op, empty(a), mk_nodes(a, ptr, ops.nodes[0]));
243265
break;
244266
}
245267
case LLVMGetElementPtr: {
246268
Nodes ops = convert_operands(p, num_ops, instr);
247-
r = prim_op_helper(a, lea_op, UNTYPED_POINTERS ? singleton(convert_type(p, LLVMGetGEPSourceElementType(instr))) : empty(a), ops);
269+
const Node* ptr = first(ops);
270+
if (UNTYPED_POINTERS) {
271+
const Type* element_t = convert_type(p, LLVMGetGEPSourceElementType(instr));
272+
const Type* untyped_ptr_t = convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0)));
273+
const Type* typed_ptr = type_untyped_ptr(untyped_ptr_t, element_t);
274+
ptr = first(bind_instruction_explicit_result_types(b, prim_op_helper(a, reinterpret_op, singleton(typed_ptr), singleton(ptr)), singleton(typed_ptr), NULL));
275+
}
276+
ops = change_node_at_index(a, ops, 0, ptr);
277+
r = prim_op_helper(a, lea_op, empty(a), ops);
278+
if (UNTYPED_POINTERS) {
279+
const Type* element_t = convert_type(p, LLVMGetGEPSourceElementType(instr));
280+
const Type* untyped_ptr_t = convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0)));
281+
bool idk;
282+
//element_t = qualified_type_helper(element_t, false);
283+
enter_composite(&element_t, &idk, nodes(a, ops.count - 2, &ops.nodes[2]), true);
284+
const Type* typed_ptr = type_untyped_ptr(untyped_ptr_t, element_t);
285+
r = prim_op_helper(a, reinterpret_op, singleton(untyped_ptr_t), BIND_PREV_R(typed_ptr));
286+
}
248287
break;
249288
}
250289
case LLVMTrunc:
@@ -493,6 +532,7 @@ EmittedInstr convert_instruction(Parser* p, Node* fn_or_bb, BodyBuilder* b, LLVM
493532
LLVMAttributeRef attr = attrs[i];
494533
size_t k = LLVMGetEnumAttributeKind(attr);
495534
size_t e = LLVMGetEnumAttributeKindForName("byval", 5);
535+
uint64_t value = LLVMGetEnumAttributeValue(attr);
496536
// printf("p = %zu, i = %zu, k = %zu, e = %zu\n", param_index, i, k, e);
497537
if (k == e)
498538
decoded[param_index].is_byval = true;

src/frontends/llvm/l2s_meta.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ static const Node* convert_named_tuple_metadata(Parser* p, LLVMValueRef v, Strin
2727
String name = LLVMGetValueName(v);
2828
if (!name || strlen(name) == 0)
2929
name = unique_name(a, node_name);
30-
Node* g = global_var(p->dst, singleton(annotation(a, (Annotation) { .name = "SkipOnInfer" })), NULL, name, AsDebugInfo);
30+
Node* g = global_var(p->dst, singleton(annotation(a, (Annotation) { .name = "LLVMMetaData" })), NULL, name, AsDebugInfo);
3131
const Node* r = ref_decl_helper(a, g);
3232
insert_dict(LLVMValueRef, const Type*, p->map, v, r);
3333

src/frontends/llvm/l2s_postprocess.c

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,36 @@ static Nodes remake_variables(Context* ctx, Nodes old) {
7979

8080
static const Node* process_op(Context* ctx, NodeClass op_class, String op_name, const Node* node) {
8181
IrArena* a = ctx->rewriter.dst_arena;
82+
Rewriter* r = &ctx->rewriter;
8283
switch (node->tag) {
83-
case Variable_TAG: return var(a, node->payload.var.type ? qualified_type_helper(rewrite_node(&ctx->rewriter, node->payload.var.type), false) : NULL, node->payload.var.name);
84+
case Variable_TAG: {
85+
assert(node->payload.var.type);
86+
if (node->payload.var.type->tag == QualifiedType_TAG)
87+
return var(a, node->payload.var.type ? rewrite_node(&ctx->rewriter, node->payload.var.type) : NULL, node->payload.var.name);
88+
return var(a, qualified_type_helper(rewrite_node(&ctx->rewriter, node->payload.var.type), false), node->payload.var.name);
89+
}
90+
case Block_TAG: {
91+
Nodes yield_types = rewrite_nodes(r, node->payload.block.yield_types);
92+
const Node* ninside = rewrite_node(r, node->payload.block.inside);
93+
const Node* term = get_abstraction_body(ninside);
94+
while (term->tag == Let_TAG) {
95+
term = get_abstraction_body(get_let_tail(term));
96+
}
97+
assert(term->tag == Yield_TAG);
98+
yield_types = get_values_types(a, term->payload.yield.args);
99+
return block(a, (Block) {
100+
.yield_types = yield_types,
101+
.inside = ninside,
102+
});
103+
}
104+
case Constant_TAG: {
105+
Node* new = recreate_node_identity(r, node);
106+
BodyBuilder* bb = begin_body(a);
107+
const Node* value = first(bind_instruction(bb, new->payload.constant.instruction));
108+
value = first(bind_instruction(bb, prim_op_helper(a, subgroup_assume_uniform_op, empty(a), singleton(value))));
109+
new->payload.constant.instruction = yield_values_and_wrap_in_block(bb, singleton(value));
110+
return new;
111+
}
84112
case Function_TAG: {
85113
Context fn_ctx = *ctx;
86114
fn_ctx.curr_scope = new_scope(node);
@@ -183,7 +211,7 @@ static const Node* process_op(Context* ctx, NodeClass op_class, String op_name,
183211
const Type* jp_type = join_point_type(a, (JoinPointType) {
184212
.yield_types = get_variables_types(a, get_abstraction_params(dst))
185213
});
186-
join_token = var(a, jp_type, get_abstraction_name(dst));
214+
join_token = var(a, qualified_type_helper(jp_type, false), get_abstraction_name(dst));
187215
controls->tokens = append_nodes(a, controls->tokens, join_token);
188216
controls->destinations = append_nodes(a, controls->destinations, dst);
189217
}
@@ -193,6 +221,7 @@ static const Node* process_op(Context* ctx, NodeClass op_class, String op_name,
193221
if (fn->tag == BasicBlock_TAG)
194222
fn = (Node*) fn->payload.basic_block.fn;
195223
assert(fn->tag == Function_TAG);
224+
fn = rewrite_node(r, fn);
196225
Node* wrapper = basic_block(a, fn, nparams, format_string_arena(a->arena, "wrapper_to_%s", get_abstraction_name(dst)));
197226
wrapper->payload.basic_block.body = join(a, (Join) {
198227
.args = nparams,
@@ -214,23 +243,25 @@ static const Node* process_op(Context* ctx, NodeClass op_class, String op_name,
214243
break;
215244
}
216245
case GlobalVariable_TAG: {
246+
if (lookup_annotation(node, "LLVMMetaData"))
247+
return NULL;
217248
AddressSpace as = node->payload.global_variable.address_space;
218249
const Node* old_init = node->payload.global_variable.init;
219-
Nodes annotations = rewrite_nodes(&ctx->rewriter, node->payload.global_variable.annotations);
220-
const Type* type = rewrite_node(&ctx->rewriter, node->payload.global_variable.type);
250+
Nodes annotations = rewrite_nodes(r, node->payload.global_variable.annotations);
251+
const Type* type = rewrite_node(r, node->payload.global_variable.type);
221252
ParsedAnnotation* an = find_annotation(ctx->p, node);
222253
while (an) {
223-
annotations = append_nodes(a, annotations, an->payload);
254+
annotations = append_nodes(a, annotations, rewrite_node(r, an->payload));
224255
if (strcmp(get_annotation_name(an->payload), "Builtin") == 0)
225256
old_init = NULL;
226257
if (strcmp(get_annotation_name(an->payload), "UniformConstant") == 0)
227258
as = AsUniformConstant;
228259
an = an->next;
229260
}
230261
Node* decl = global_var(ctx->rewriter.dst_module, annotations, type, get_declaration_name(node), as);
231-
register_processed(&ctx->rewriter, node, decl);
262+
register_processed(r, node, decl);
232263
if (old_init)
233-
decl->payload.global_variable.init = rewrite_node(&ctx->rewriter, old_init);
264+
decl->payload.global_variable.init = rewrite_node(r, old_init);
234265
return decl;
235266
}
236267
default: break;

src/frontends/llvm/l2s_type.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "l2s_private.h"
2+
#include "type.h"
23

34
#include "portability.h"
45
#include "log.h"
@@ -39,6 +40,8 @@ const Type* convert_type(Parser* p, LLVMTypeRef t) {
3940
const Type* ret_type = convert_type(p, LLVMGetReturnType(t));
4041
if (LLVMGetTypeKind(LLVMGetReturnType(t)) == LLVMVoidTypeKind)
4142
ret_type = empty_multiple_return_type(a);
43+
else
44+
ret_type = qualified_type_helper(ret_type, false);
4245
return fn_type(a, (FnType) {
4346
.param_types = nodes(a, num_params, cparam_types),
4447
.return_types = ret_type == empty_multiple_return_type(a) ? empty(a) : singleton(ret_type)
@@ -96,6 +99,8 @@ const Type* convert_type(Parser* p, LLVMTypeRef t) {
9699
#if !UNTYPED_POINTERS
97100
LLVMTypeRef element_type = LLVMGetElementType(t);
98101
pointee = convert_type(p, element_type);
102+
#else
103+
pointee = unit_type(a);
99104
#endif
100105
return ptr_type(a, (PtrType) {
101106
.address_space = as,

src/frontends/llvm/l2s_value.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,13 @@ const Node* convert_value(Parser* p, LLVMValueRef v) {
8080
name = unique_name(a, "constant_expr");
8181
Nodes annotations = singleton(annotation(a, (Annotation) { .name = "SkipOnInfer" }));
8282
annotations = empty(a);
83-
Node* decl = constant(p->dst, annotations, NULL, name);
83+
assert(t);
84+
Node* decl = constant(p->dst, annotations, t, name);
8485
r = ref_decl_helper(a, decl);
8586
insert_dict(LLVMTypeRef, const Type*, p->map, v, r);
8687
BodyBuilder* bb = begin_body(a);
8788
EmittedInstr emitted = convert_instruction(p, NULL, bb, v);
88-
Nodes types = singleton(convert_type(p, LLVMTypeOf(v)));
89+
Nodes types = singleton(t);
8990
decl->payload.constant.instruction = bind_last_instruction_and_wrap_in_block_explicit_return_types(bb, emitted.instruction, &types);
9091
return r;
9192
}

src/shady/analysis/verify.c

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,12 @@ static void verify_nominal_node(const Node* fn, const Node* n) {
7979
break;
8080
}
8181
case Constant_TAG: {
82-
const Type* t = n->payload.constant.instruction->type;
83-
bool u = deconstruct_qualified_type(&t);
84-
assert(u);
85-
assert(is_subtype(n->payload.constant.type_hint, t));
82+
if (n->payload.constant.instruction) {
83+
const Type* t = n->payload.constant.instruction->type;
84+
bool u = deconstruct_qualified_type(&t);
85+
assert(u);
86+
assert(is_subtype(n->payload.constant.type_hint, t));
87+
}
8688
break;
8789
}
8890
case GlobalVariable_TAG: {

src/shady/compile.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ CompilerConfig default_compiler_config() {
3030

3131
.logging = {
3232
// most of the time, we are not interested in seeing generated & internal code in the debug output
33-
.print_internal = true,
34-
.print_generated = true,
33+
//.print_internal = true,
34+
//.print_generated = true,
35+
.print_builtin = true,
3536
},
3637

3738
.optimisations = {

src/shady/node.c

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ const Node* resolve_node_to_definition(const Node* node, NodeResolveConfig confi
140140
case PrimOp_TAG: {
141141
switch (node->payload.prim_op.op) {
142142
case quote_op: {
143-
node = first(node->payload.prim_op.operands);;
143+
node = first(node->payload.prim_op.operands);
144144
continue;
145145
}
146146
case load_op: {
@@ -200,22 +200,19 @@ const char* get_string_literal(IrArena* arena, const Node* node) {
200200
if (!node)
201201
return NULL;
202202
switch (node->tag) {
203+
case Declaration_GlobalVariable_TAG: {
204+
const Node* init = node->payload.global_variable.init;
205+
if (init) {
206+
return get_string_literal(arena, init);
207+
}
208+
break;
209+
}
210+
case Declaration_Constant_TAG: {
211+
return get_string_literal(arena, node->payload.constant.instruction);
212+
}
203213
case RefDecl_TAG: {
204214
const Node* decl = node->payload.ref_decl.decl;
205-
switch (is_declaration(decl)) {
206-
case Declaration_GlobalVariable_TAG: {
207-
const Node* init = decl->payload.global_variable.init;
208-
if (init)
209-
return get_string_literal(arena, init);
210-
break;
211-
}
212-
case Declaration_Constant_TAG: {
213-
return get_string_literal(arena, decl->payload.constant.instruction);
214-
}
215-
default:
216-
break;
217-
}
218-
return NULL;
215+
return get_string_literal(arena, decl);
219216
}
220217
case PrimOp_TAG: {
221218
switch (node->payload.prim_op.op) {

0 commit comments

Comments
 (0)