Skip to content

Commit b1fb29f

Browse files
committed
coop: wgsl parsing, IR role
1 parent 82df04d commit b1fb29f

22 files changed

+486
-20
lines changed

naga/src/back/msl/writer.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ impl Display for TypeContext<'_> {
239239
columns,
240240
rows,
241241
scalar,
242+
role: _,
242243
} => {
243244
write!(
244245
out,

naga/src/back/spv/instructions.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,16 +284,17 @@ impl super::Instruction {
284284
pub(super) fn type_coop_matrix(
285285
id: Word,
286286
scalar_type_id: Word,
287-
row_count: crate::CooperativeVectorSize,
288-
column_count: crate::CooperativeVectorSize,
287+
row_count: crate::CooperativeSize,
288+
column_count: crate::CooperativeSize,
289+
role: spirv::CooperativeMatrixUse,
289290
) -> Self {
290291
let mut instruction = Self::new(Op::TypeCooperativeMatrixKHR);
291292
instruction.set_result(id);
292293
instruction.add_operand(scalar_type_id);
293294
instruction.add_operand(spirv::Scope::Subgroup as u32);
294295
instruction.add_operand(column_count as u32);
295296
instruction.add_operand(row_count as u32);
296-
instruction.add_operand(spirv::CooperativeMatrixUse::MatrixAKHR as u32); //TODO: configure or expose
297+
instruction.add_operand(role as u32);
297298
instruction
298299
}
299300

@@ -1305,3 +1306,13 @@ impl From<crate::ImageDimension> for spirv::Dim {
13051306
}
13061307
}
13071308
}
1309+
1310+
impl From<crate::CooperativeRole> for spirv::CooperativeMatrixUse {
1311+
fn from(role: crate::CooperativeRole) -> Self {
1312+
match role {
1313+
crate::CooperativeRole::A => Self::MatrixAKHR,
1314+
crate::CooperativeRole::B => Self::MatrixBKHR,
1315+
crate::CooperativeRole::C => Self::MatrixAccumulatorKHR,
1316+
}
1317+
}
1318+
}

naga/src/back/spv/mod.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,9 +342,10 @@ impl NumericType {
342342
#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)]
343343
enum CooperativeType {
344344
Matrix {
345-
columns: crate::CooperativeVectorSize,
346-
rows: crate::CooperativeVectorSize,
345+
columns: crate::CooperativeSize,
346+
rows: crate::CooperativeSize,
347347
scalar: crate::CooperativeScalar,
348+
role: crate::CooperativeRole,
348349
},
349350
}
350351

@@ -355,10 +356,12 @@ impl CooperativeType {
355356
columns,
356357
rows,
357358
scalar,
359+
role,
358360
} => Some(Self::Matrix {
359361
columns,
360362
rows,
361363
scalar,
364+
role,
362365
}),
363366
_ => None,
364367
}

naga/src/back/spv/writer.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1380,9 +1380,10 @@ impl Writer {
13801380
columns,
13811381
rows,
13821382
scalar,
1383+
role,
13831384
} => {
13841385
let scalar_id = self.get_cooperative_type_id(scalar);
1385-
Instruction::type_coop_matrix(id, scalar_id, rows, columns)
1386+
Instruction::type_coop_matrix(id, scalar_id, rows, columns, role.into())
13861387
}
13871388
};
13881389

naga/src/common/wgsl/to_wgsl.rs

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -318,15 +318,23 @@ impl TryToWgsl for crate::CooperativeScalar {
318318
}
319319
}
320320

321-
impl ToWgsl for crate::ImageDimension {
321+
impl ToWgsl for crate::CooperativeRole {
322322
fn to_wgsl(self) -> &'static str {
323-
use crate::ImageDimension as IDim;
323+
match self {
324+
Self::A => "A",
325+
Self::B => "B",
326+
Self::C => "C",
327+
}
328+
}
329+
}
324330

331+
impl ToWgsl for crate::ImageDimension {
332+
fn to_wgsl(self) -> &'static str {
325333
match self {
326-
IDim::D1 => "1d",
327-
IDim::D2 => "2d",
328-
IDim::D3 => "3d",
329-
IDim::Cube => "cube",
334+
Self::D1 => "1d",
335+
Self::D2 => "2d",
336+
Self::D3 => "3d",
337+
Self::Cube => "cube",
330338
}
331339
}
332340
}

naga/src/common/wgsl/types.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,13 +321,15 @@ where
321321
columns,
322322
rows,
323323
scalar,
324+
role,
324325
} => {
325326
write!(
326327
out,
327-
"coop_mat{}x{}<{}>",
328+
"coop_mat{}x{}<{},{}>",
328329
columns as u32,
329330
rows as u32,
330-
scalar.try_to_wgsl().unwrap_or_default()
331+
scalar.try_to_wgsl().unwrap_or_default(),
332+
role.to_wgsl(),
331333
)?;
332334
}
333335
TypeInner::Pointer { base, space } => {

naga/src/front/wgsl/error.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,8 @@ pub(crate) enum Error<'a> {
416416
TypeTooLarge {
417417
span: Span,
418418
},
419+
UnderspecifiedCooperativeMatrix,
420+
UnknownCooperativeScalar(Span),
419421
}
420422

421423
impl From<ConflictingDiagnosticRuleError> for Error<'_> {
@@ -1390,6 +1392,16 @@ impl<'a> Error<'a> {
13901392
crate::valid::MAX_TYPE_SIZE
13911393
)],
13921394
},
1395+
Error::UnderspecifiedCooperativeMatrix => ParseError {
1396+
message: "cooperative matrix constructor is underspecified".into(),
1397+
labels: vec![],
1398+
notes: vec![format!("must be F32")],
1399+
},
1400+
Error::UnknownCooperativeScalar(span) => ParseError {
1401+
message: "unknown cooperative scalar type".into(),
1402+
labels: vec![(span, "type needs the scalar type specified".into())],
1403+
notes: vec![format!("must be F32")],
1404+
},
13931405
}
13941406
}
13951407
}

naga/src/front/wgsl/lower/construction.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,32 @@ impl<'source> Lowerer<'source, '_> {
638638
};
639639
Constructor::Type(ty)
640640
}
641+
ast::ConstructorType::PartialCooperativeMatrix { .. } => {
642+
return Err(Box::new(Error::UnderspecifiedCooperativeMatrix));
643+
}
644+
ast::ConstructorType::CooperativeMatrix {
645+
rows,
646+
columns,
647+
ty,
648+
ty_span,
649+
role,
650+
} => {
651+
let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?;
652+
let scalar = match ctx.module.types[ty].inner {
653+
crate::TypeInner::Scalar(crate::Scalar {
654+
kind: crate::ScalarKind::Float,
655+
width: 4,
656+
}) => crate::CooperativeScalar::F32,
657+
_ => return Err(Box::new(Error::UnknownCooperativeScalar(ty_span))),
658+
};
659+
let ty = ctx.ensure_type_exists(crate::TypeInner::CooperativeMatrix {
660+
columns,
661+
rows,
662+
scalar,
663+
role,
664+
});
665+
Constructor::Type(ty)
666+
}
641667
ast::ConstructorType::PartialArray => Constructor::PartialArray,
642668
ast::ConstructorType::Array { base, size } => {
643669
let base = self.resolve_ast_type(base, &mut ctx.as_const())?;

naga/src/front/wgsl/lower/mod.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3955,6 +3955,28 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
39553955
_ => return Err(Box::new(Error::BadMatrixScalarKind(ty_span, scalar))),
39563956
}
39573957
}
3958+
ast::Type::CooperativeMatrix {
3959+
columns,
3960+
rows,
3961+
ty,
3962+
ty_span,
3963+
role,
3964+
} => {
3965+
let ty = self.resolve_ast_type(ty, ctx)?;
3966+
let scalar = match ctx.module.types[ty].inner {
3967+
ir::TypeInner::Scalar(crate::Scalar {
3968+
kind: crate::ScalarKind::Float,
3969+
width: 4,
3970+
}) => crate::CooperativeScalar::F32,
3971+
_ => return Err(Box::new(Error::UnknownCooperativeScalar(ty_span))),
3972+
};
3973+
ir::TypeInner::CooperativeMatrix {
3974+
columns,
3975+
rows,
3976+
scalar,
3977+
role,
3978+
}
3979+
}
39583980
ast::Type::Atomic(scalar) => scalar.to_inner_atomic(),
39593981
ast::Type::Pointer { base, space } => {
39603982
let base = self.resolve_ast_type(base, ctx)?;

naga/src/front/wgsl/parse/ast.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,13 @@ pub enum Type<'a> {
235235
ty: Handle<Type<'a>>,
236236
ty_span: Span,
237237
},
238+
CooperativeMatrix {
239+
columns: crate::CooperativeSize,
240+
rows: crate::CooperativeSize,
241+
ty: Handle<Type<'a>>,
242+
ty_span: Span,
243+
role: crate::CooperativeRole,
244+
},
238245
Atomic(Scalar),
239246
Pointer {
240247
base: Handle<Type<'a>>,
@@ -385,6 +392,21 @@ pub enum ConstructorType<'a> {
385392
ty_span: Span,
386393
},
387394

395+
/// A cooperative matrix construction base `coop_mat8x8(...)`.
396+
PartialCooperativeMatrix {
397+
columns: crate::CooperativeSize,
398+
rows: crate::CooperativeSize,
399+
},
400+
401+
/// A full cooperative matrix construction `coop_mat8x8<f32,A>(...)`.
402+
CooperativeMatrix {
403+
columns: crate::CooperativeSize,
404+
rows: crate::CooperativeSize,
405+
ty: Handle<Type<'a>>,
406+
ty_span: Span,
407+
role: crate::CooperativeRole,
408+
},
409+
388410
/// An array whose component type and size are inferred from the arguments:
389411
/// `array(3,4,5)`.
390412
PartialArray,

0 commit comments

Comments
 (0)