Skip to content

Commit 5cea342

Browse files
committed
implement an FMA op
1 parent 72dee1b commit 5cea342

File tree

5 files changed

+48
-6
lines changed

5 files changed

+48
-6
lines changed

include/shady/primops.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@
6666
"name": "div",
6767
"class": "arithmetic"
6868
},
69+
{
70+
"name": "fma",
71+
"class": "arithmetic"
72+
},
6973
{
7074
"name": "mod",
7175
"class": "arithmetic"

src/frontends/llvm/l2s_instr.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -497,8 +497,9 @@ EmittedInstr convert_instruction(Parser* p, Node* fn_or_bb, BodyBuilder* b, LLVM
497497
} else if (string_starts_with(intrinsic, "llvm.fmuladd")) {
498498
Nodes ops = convert_operands(p, num_ops, instr);
499499
num_results = 1;
500-
r = prim_op_helper(a, mul_op, empty(a), nodes(a, 2, ops.nodes));
501-
r = prim_op_helper(a, add_op, empty(a), mk_nodes(a, first(BIND_PREV_R(convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0))))), ops.nodes[2]));
500+
r = prim_op_helper(a, fma_op, empty(a), nodes(a, 3, ops.nodes));
501+
// r = prim_op_helper(a, mul_op, empty(a), nodes(a, 2, ops.nodes));
502+
// r = prim_op_helper(a, add_op, empty(a), mk_nodes(a, first(BIND_PREV_R(convert_type(p, LLVMTypeOf(LLVMGetOperand(instr, 0))))), ops.nodes[2]));
502503
break;
503504
} else if (string_starts_with(intrinsic, "llvm.fabs")) {
504505
Nodes ops = convert_operands(p, num_ops, instr);

src/shady/emit/c/emit_c_instructions.c

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,23 @@ static void emit_primop(Emitter* emitter, Printer* p, const Node* node, Instruct
270270
term = term_from_cvalue(format_string_arena(arena->arena, "(%s > 0 ? 1 : -1)", src));
271271
break;
272272
}
273+
case fma_op: {
274+
CValue a = to_cvalue(emitter, emit_value(emitter, p, prim_op->operands.nodes[0]));
275+
CValue b = to_cvalue(emitter, emit_value(emitter, p, prim_op->operands.nodes[1]));
276+
CValue c = to_cvalue(emitter, emit_value(emitter, p, prim_op->operands.nodes[2]));
277+
switch (emitter->config.dialect) {
278+
case CDialect_C11:
279+
case CDialect_CUDA: {
280+
term = term_from_cvalue(format_string_arena(arena->arena, "fmaf(%s, %s, %s)", a, b, c));
281+
break;
282+
}
283+
default: {
284+
term = term_from_cvalue(format_string_arena(arena->arena, "(%s * %s) + %s", a, b, c));
285+
break;
286+
}
287+
}
288+
break;
289+
}
273290
case lshift_op:
274291
case rshift_arithm_op:
275292
case rshift_logical_op: {

src/shady/emit/spirv/emit_spv_instructions.c

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,11 @@ static const IselTableEntry isel_table[] = {
124124
[abs_op] = { Plain, FirstOp, Same, .extended_set = "GLSL.std.450", .fo = { (SpvOp) GLSLstd450SAbs, ISEL_ILLEGAL, (SpvOp) GLSLstd450FAbs, ISEL_ILLEGAL }},
125125
[sign_op] = { Plain, FirstOp, Same, .extended_set = "GLSL.std.450", .fo = { (SpvOp) GLSLstd450SSign, ISEL_ILLEGAL, (SpvOp) GLSLstd450FSign, ISEL_ILLEGAL }},
126126

127-
[min_op] = {Plain, FirstOp, Same, .extended_set = "GLSL.std.450", .fo = {(SpvOp) GLSLstd450SMin, (SpvOp) GLSLstd450UMin, (SpvOp) GLSLstd450FMin, ISEL_ILLEGAL, ISEL_ILLEGAL }},
128-
[max_op] = {Plain, FirstOp, Same, .extended_set = "GLSL.std.450", .fo = {(SpvOp) GLSLstd450SMax, (SpvOp) GLSLstd450UMax, (SpvOp) GLSLstd450FMax, ISEL_ILLEGAL, ISEL_ILLEGAL }},
129-
[exp_op] = {Plain, FirstOp, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Exp},
130-
[pow_op] = {Plain, FirstOp, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Pow},
127+
[min_op] = { Plain, FirstOp, Same, .extended_set = "GLSL.std.450", .fo = {(SpvOp) GLSLstd450SMin, (SpvOp) GLSLstd450UMin, (SpvOp) GLSLstd450FMin, ISEL_ILLEGAL, ISEL_ILLEGAL }},
128+
[max_op] = { Plain, FirstOp, Same, .extended_set = "GLSL.std.450", .fo = {(SpvOp) GLSLstd450SMax, (SpvOp) GLSLstd450UMax, (SpvOp) GLSLstd450FMax, ISEL_ILLEGAL, ISEL_ILLEGAL }},
129+
[exp_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Exp },
130+
[pow_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Pow },
131+
[fma_op] = { Plain, Monomorphic, Same, .extended_set = "GLSL.std.450", .op = (SpvOp) GLSLstd450Fma },
131132

132133
[debug_printf_op] = {Plain, Monomorphic, Void, .extended_set = "NonSemantic.DebugPrintf", .op = (SpvOp) NonSemanticDebugPrintfDebugPrintf},
133134

src/shady/type.c

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,25 @@ const Type* check_type_prim_op(IrArena* arena, PrimOp prim_op) {
698698

699699
return qualified_type_helper(first_operand_type, result_uniform);
700700
}
701+
case fma_op: {
702+
assert(prim_op.type_arguments.count == 0);
703+
assert(prim_op.operands.count == 3);
704+
const Type* first_operand_type = get_unqualified_type(first(prim_op.operands)->type);
705+
706+
bool result_uniform = true;
707+
for (size_t i = 0; i < prim_op.operands.count; i++) {
708+
const Node* arg = prim_op.operands.nodes[i];
709+
const Type* operand_type = arg->type;
710+
bool operand_uniform = deconstruct_qualified_type(&operand_type);
711+
712+
assert(get_maybe_packed_type_element(operand_type)->tag == Float_TAG);
713+
assert(first_operand_type == operand_type && "operand type mismatch");
714+
715+
result_uniform &= operand_uniform;
716+
}
717+
718+
return qualified_type_helper(first_operand_type, result_uniform);
719+
}
701720
case abs_op:
702721
case sign_op:
703722
{

0 commit comments

Comments
 (0)