Skip to content

Commit 4f614bc

Browse files
committed
coop: mulAdd instruction
1 parent fcae596 commit 4f614bc

File tree

18 files changed

+141
-9
lines changed

18 files changed

+141
-9
lines changed

naga/src/back/dot/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,12 @@ fn write_function_expressions(
742742
let ty = if committed { "Committed" } else { "Candidate" };
743743
(format!("get{ty}HitVertexPositions").into(), 4)
744744
}
745+
E::MulAdd { a, b, c } => {
746+
edges.insert("a", a);
747+
edges.insert("b", b);
748+
edges.insert("c", c);
749+
("MulAdd".into(), 6)
750+
}
745751
};
746752

747753
// give uniform expressions an outline

naga/src/back/glsl/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4341,7 +4341,8 @@ impl<'a, W: Write> Writer<'a, W> {
43414341
}
43424342
// not supported yet
43434343
Expression::RayQueryGetIntersection { .. }
4344-
| Expression::RayQueryVertexPositions { .. } => unreachable!(),
4344+
| Expression::RayQueryVertexPositions { .. }
4345+
| Expression::MulAdd { .. } => unreachable!(),
43454346
}
43464347

43474348
Ok(())

naga/src/back/hlsl/writer.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4275,7 +4275,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
42754275
}
42764276
}
42774277
// Not supported yet
4278-
Expression::RayQueryVertexPositions { .. } => unreachable!(),
4278+
Expression::RayQueryVertexPositions { .. } | Expression::MulAdd { .. } => {
4279+
unreachable!()
4280+
}
42794281
// Nothing to do here, since call expression already cached
42804282
Expression::CallResult(_)
42814283
| Expression::AtomicResult { .. }

naga/src/back/msl/writer.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2842,6 +2842,13 @@ impl<W: Write> Writer<W> {
28422842
}
28432843
write!(self.out, "}}")?;
28442844
}
2845+
crate::Expression::MulAdd { a, b, c } => {
2846+
self.put_expression(a, context, false)?;
2847+
write!(self.out, " * ")?;
2848+
self.put_expression(b, context, false)?;
2849+
write!(self.out, " + ")?;
2850+
self.put_expression(c, context, false)?;
2851+
}
28452852
}
28462853
Ok(())
28472854
}

naga/src/back/pipeline_constants.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,15 @@ fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut E
633633
} => {
634634
adjust(query);
635635
}
636+
Expression::MulAdd {
637+
ref mut a,
638+
ref mut b,
639+
ref mut c,
640+
} => {
641+
adjust(a);
642+
adjust(b);
643+
adjust(c);
644+
}
636645
}
637646
}
638647

naga/src/back/spv/block.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1805,6 +1805,17 @@ impl BlockContext<'_> {
18051805
)?;
18061806
self.write_ray_query_return_vertex_position(query, block, committed)
18071807
}
1808+
crate::Expression::MulAdd { a, b, c } => {
1809+
let id = self.gen_id();
1810+
block.body.push(Instruction::coop_mul_add(
1811+
result_type_id,
1812+
id,
1813+
self.cached[a],
1814+
self.cached[b],
1815+
self.cached[c],
1816+
));
1817+
id
1818+
}
18081819
};
18091820

18101821
self.cached[expr_handle] = id;

naga/src/back/spv/instructions.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,6 +1245,18 @@ impl super::Instruction {
12451245

12461246
instruction
12471247
}
1248+
1249+
// Cooperative operations
1250+
pub(super) fn coop_mul_add(result_type_id: Word, id: Word, a: Word, b: Word, c: Word) -> Self {
1251+
let mut instruction = Self::new(Op::CooperativeMatrixMulAddKHR);
1252+
instruction.set_type(result_type_id);
1253+
instruction.set_result(id);
1254+
instruction.add_operand(a);
1255+
instruction.add_operand(b);
1256+
instruction.add_operand(c);
1257+
1258+
instruction
1259+
}
12481260
}
12491261

12501262
impl From<crate::StorageFormat> for spirv::ImageFormat {

naga/src/back/wgsl/writer.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1685,6 +1685,15 @@ impl<W: Write> Writer<W> {
16851685

16861686
write!(self.out, ")")?
16871687
}
1688+
Expression::MulAdd { a, b, c } => {
1689+
write!(self.out, "mulAdd(")?;
1690+
self.write_expr(module, a, func_ctx)?;
1691+
write!(self.out, ", ")?;
1692+
self.write_expr(module, b, func_ctx)?;
1693+
write!(self.out, ", ")?;
1694+
self.write_expr(module, c, func_ctx)?;
1695+
write!(self.out, ")")?
1696+
}
16881697
// Not supported yet
16891698
Expression::RayQueryGetIntersection { .. }
16901699
| Expression::RayQueryVertexPositions { .. } => unreachable!(),

naga/src/compact/expressions.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,9 @@ impl ExpressionTracer<'_> {
253253
} => {
254254
self.expressions_used.insert(query);
255255
}
256+
Ex::MulAdd { a, b, c } => {
257+
self.expressions_used.insert_iter([a, b, c]);
258+
}
256259
}
257260
}
258261
}
@@ -419,6 +422,15 @@ impl ModuleMap {
419422
ref mut query,
420423
committed: _,
421424
} => adjust(query),
425+
Ex::MulAdd {
426+
ref mut a,
427+
ref mut b,
428+
ref mut c,
429+
} => {
430+
adjust(a);
431+
adjust(b);
432+
adjust(c);
433+
}
422434
}
423435
}
424436

naga/src/front/wgsl/lower/mod.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3082,7 +3082,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
30823082
);
30833083
return Ok(Some(result));
30843084
}
3085-
30863085
"quadSwapY" => {
30873086
let mut args = ctx.prepare_args(arguments, 1, span);
30883087

@@ -3106,7 +3105,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
31063105
);
31073106
return Ok(Some(result));
31083107
}
3109-
31103108
"quadSwapDiagonal" => {
31113109
let mut args = ctx.prepare_args(arguments, 1, span);
31123110

@@ -3130,6 +3128,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
31303128
);
31313129
return Ok(Some(result));
31323130
}
3131+
"coopMulAdd" => {
3132+
let mut args = ctx.prepare_args(arguments, 3, span);
3133+
let a = self.expression(args.next()?, ctx)?;
3134+
let b = self.expression(args.next()?, ctx)?;
3135+
let c = self.expression(args.next()?, ctx)?;
3136+
args.finish()?;
3137+
3138+
ir::Expression::MulAdd { a, b, c }
3139+
}
31333140
_ => {
31343141
return Err(Box::new(Error::UnknownIdent(function.span, function.name)))
31353142
}

0 commit comments

Comments
 (0)