Skip to content

Commit 0a3cba4

Browse files
committed
Add Cooperative* type to IR
1 parent ab57c10 commit 0a3cba4

File tree

13 files changed

+145
-2
lines changed

13 files changed

+145
-2
lines changed

naga/src/back/glsl/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1107,7 +1107,8 @@ impl<'a, W: Write> Writer<'a, W> {
11071107
TypeInner::Array { base, size, .. } => self.write_array_size(base, size)?,
11081108
// Write all variants instead of `_` so that if new variants are added a
11091109
// no exhaustiveness error is thrown
1110-
TypeInner::Pointer { .. }
1110+
TypeInner::CooperativeMatrix { .. }
1111+
| TypeInner::Pointer { .. }
11111112
| TypeInner::Struct { .. }
11121113
| TypeInner::Image { .. }
11131114
| TypeInner::Sampler { .. }

naga/src/back/msl/writer.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,20 @@ impl Display for TypeContext<'_> {
235235
rows,
236236
scalar,
237237
} => put_numeric_type(out, scalar, &[rows, columns]),
238+
crate::TypeInner::CooperativeMatrix {
239+
columns,
240+
rows,
241+
scalar,
242+
} => {
243+
write!(
244+
out,
245+
"{}::simdgroup_{}{}x{}",
246+
NAMESPACE,
247+
scalar.to_msl_name(),
248+
columns as u32,
249+
rows as u32,
250+
)
251+
}
238252
crate::TypeInner::Pointer { base, space } => {
239253
let sub = Self {
240254
handle: base,
@@ -528,6 +542,14 @@ impl crate::Scalar {
528542
}
529543
}
530544

545+
impl crate::CooperativeScalar {
546+
const fn to_msl_name(self) -> &'static str {
547+
match self {
548+
Self::F32 => "float",
549+
}
550+
}
551+
}
552+
531553
const fn separate(need_separator: bool) -> &'static str {
532554
if need_separator {
533555
","
@@ -637,6 +659,7 @@ impl crate::Type {
637659
Ti::Scalar(_)
638660
| Ti::Vector { .. }
639661
| Ti::Matrix { .. }
662+
| Ti::CooperativeMatrix { .. }
640663
| Ti::Atomic(_)
641664
| Ti::Pointer { .. }
642665
| Ti::ValuePointer { .. } => self.name.is_some(),

naga/src/back/spv/writer.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ impl Writer {
436436
// these cases, so unwrap.
437437
LocalType::Numeric(NumericType::from_inner(inner).unwrap())
438438
}
439+
crate::TypeInner::CooperativeMatrix { .. } => return None,
439440
crate::TypeInner::Pointer { base, space } => {
440441
let base_type_id = self.get_handle_type_id(base);
441442
LocalType::Pointer {
@@ -1500,6 +1501,7 @@ impl Writer {
15001501
| crate::TypeInner::Atomic(_)
15011502
| crate::TypeInner::Vector { .. }
15021503
| crate::TypeInner::Matrix { .. }
1504+
| crate::TypeInner::CooperativeMatrix { .. }
15031505
| crate::TypeInner::Pointer { .. }
15041506
| crate::TypeInner::ValuePointer { .. }
15051507
| crate::TypeInner::Image { .. }

naga/src/common/wgsl/to_wgsl.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,25 @@ impl TryToWgsl for crate::Scalar {
299299
}
300300
}
301301

302+
impl TryToWgsl for crate::CooperativeScalar {
303+
const DESCRIPTION: &'static str = "cooperative scalar type";
304+
305+
fn try_to_wgsl(self) -> Option<&'static str> {
306+
use crate::CooperativeScalar;
307+
308+
Some(match self {
309+
CooperativeScalar::F32 => "f32",
310+
})
311+
}
312+
313+
fn to_wgsl_for_diagnostics(self) -> String {
314+
match self.try_to_wgsl() {
315+
Some(static_string) => static_string.to_string(),
316+
None => unreachable!(),
317+
}
318+
}
319+
}
320+
302321
impl ToWgsl for crate::ImageDimension {
303322
fn to_wgsl(self) -> &'static str {
304323
use crate::ImageDimension as IDim;

naga/src/common/wgsl/types.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,19 @@ where
317317
ctx.write_scalar(scalar, out)?;
318318
out.write_str(">")?;
319319
}
320+
TypeInner::CooperativeMatrix {
321+
columns,
322+
rows,
323+
scalar,
324+
} => {
325+
write!(
326+
out,
327+
"coop_mat{}x{}<{}>",
328+
columns as u32,
329+
rows as u32,
330+
scalar.try_to_wgsl().unwrap_or_default()
331+
)?;
332+
}
320333
TypeInner::Pointer { base, space } => {
321334
let (address, maybe_access) = address_space_str(space);
322335
// Everything but `AddressSpace::Handle` gives us a `address` name, but

naga/src/compact/types.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ impl TypeTracer<'_> {
1616
Ti::Scalar { .. }
1717
| Ti::Vector { .. }
1818
| Ti::Matrix { .. }
19+
| Ti::CooperativeMatrix { .. }
1920
| Ti::Atomic { .. }
2021
| Ti::ValuePointer { .. }
2122
| Ti::Image { .. }
@@ -66,6 +67,7 @@ impl ModuleMap {
6667
Ti::Scalar(_)
6768
| Ti::Vector { .. }
6869
| Ti::Matrix { .. }
70+
| Ti::CooperativeMatrix { .. }
6971
| Ti::Atomic(_)
7072
| Ti::ValuePointer { .. }
7173
| Ti::Image { .. }

naga/src/front/wgsl/lower/conversion.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ impl crate::TypeInner {
350350
Ti::Scalar(scalar) | Ti::Vector { scalar, .. } | Ti::Matrix { scalar, .. } => {
351351
Some(scalar)
352352
}
353+
Ti::CooperativeMatrix { .. } => None,
353354
Ti::Array { base, .. } => types[base].inner.automatically_convertible_scalar(types),
354355
Ti::Atomic(_)
355356
| Ti::Pointer { .. }
@@ -375,6 +376,7 @@ impl crate::TypeInner {
375376
Ti::Scalar(scalar) | Ti::Vector { scalar, .. } | Ti::Matrix { scalar, .. } => {
376377
Some(scalar)
377378
}
379+
Ti::CooperativeMatrix { .. } => None,
378380
Ti::Atomic(_) => None,
379381
Ti::Pointer { base, .. } | Ti::Array { base, .. } => {
380382
types[base].inner.automatically_convertible_scalar(types)

naga/src/ir/mod.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,16 @@ impl From<VectorSize> for u32 {
437437
}
438438
}
439439

440+
/// Number of components in a cooperative vector.
441+
#[repr(u8)]
442+
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
443+
#[cfg_attr(feature = "serialize", derive(Serialize))]
444+
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
445+
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
446+
pub enum CooperativeVectorSize {
447+
Eight = 8,
448+
}
449+
440450
/// Primitive type for a scalar.
441451
#[repr(u8)]
442452
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
@@ -464,6 +474,24 @@ pub enum ScalarKind {
464474
AbstractFloat,
465475
}
466476

477+
/// Primitive type for a scalar.
478+
#[repr(u8)]
479+
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
480+
#[cfg_attr(feature = "serialize", derive(Serialize))]
481+
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
482+
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
483+
pub enum CooperativeScalar {
484+
F32,
485+
}
486+
487+
impl CooperativeScalar {
488+
pub const fn width(&self) -> Bytes {
489+
match *self {
490+
Self::F32 => 4,
491+
}
492+
}
493+
}
494+
467495
/// Characteristics of a scalar type.
468496
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
469497
#[cfg_attr(feature = "serialize", derive(Serialize))]
@@ -712,6 +740,13 @@ pub enum TypeInner {
712740
rows: VectorSize,
713741
scalar: Scalar,
714742
},
743+
/// Matrix that is cooperatively processed by all the threads
744+
/// in an opaque mapping.
745+
CooperativeMatrix {
746+
columns: CooperativeVectorSize,
747+
rows: CooperativeVectorSize,
748+
scalar: CooperativeScalar,
749+
},
715750
/// Atomic scalar.
716751
Atomic(Scalar),
717752
/// Pointer to another type.

naga/src/proc/layouter.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ impl From<crate::VectorSize> for Alignment {
8686
}
8787
}
8888

89+
impl From<crate::CooperativeVectorSize> for Alignment {
90+
fn from(size: crate::CooperativeVectorSize) -> Self {
91+
Self(unsafe { NonZeroU32::new_unchecked(size as u32) })
92+
}
93+
}
94+
8995
/// Size and alignment information for a type.
9096
#[derive(Clone, Copy, Debug, Hash, PartialEq)]
9197
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
@@ -212,6 +218,18 @@ impl Layouter {
212218
alignment: Alignment::from(rows) * alignment,
213219
}
214220
}
221+
Ti::CooperativeMatrix {
222+
columns: _,
223+
rows,
224+
scalar,
225+
} => {
226+
let alignment = Alignment::new(scalar.width() as u32)
227+
.ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
228+
TypeLayout {
229+
size,
230+
alignment: Alignment::from(rows) * alignment,
231+
}
232+
}
215233
Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout {
216234
size,
217235
alignment: Alignment::ONE,

naga/src/proc/type_methods.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,11 @@ impl crate::TypeInner {
202202
rows,
203203
scalar,
204204
} => Some(super::Alignment::from(rows) * scalar.width as u32 * columns as u32),
205+
Self::CooperativeMatrix {
206+
columns,
207+
rows,
208+
scalar,
209+
} => Some(columns as u32 * rows as u32 * scalar.width() as u32),
205210
Self::Pointer { .. } | Self::ValuePointer { .. } => Some(POINTER_SPAN),
206211
Self::Array {
207212
base: _,
@@ -361,6 +366,7 @@ impl crate::TypeInner {
361366
crate::TypeInner::Scalar(scalar) => Some((None, scalar)),
362367
crate::TypeInner::Vector { size, scalar } => Some((Some(size), scalar)),
363368
crate::TypeInner::Matrix { .. }
369+
| crate::TypeInner::CooperativeMatrix { .. }
364370
| crate::TypeInner::Atomic(_)
365371
| crate::TypeInner::Pointer { .. }
366372
| crate::TypeInner::ValuePointer { .. }
@@ -385,7 +391,8 @@ impl crate::TypeInner {
385391
| crate::TypeInner::Matrix { scalar, .. }
386392
| crate::TypeInner::Atomic(scalar) => scalar.is_abstract(),
387393
crate::TypeInner::Array { base, .. } => types[base].inner.is_abstract(types),
388-
crate::TypeInner::ValuePointer { .. }
394+
crate::TypeInner::CooperativeMatrix { .. }
395+
| crate::TypeInner::ValuePointer { .. }
389396
| crate::TypeInner::Pointer { .. }
390397
| crate::TypeInner::Struct { .. }
391398
| crate::TypeInner::Image { .. }

0 commit comments

Comments
 (0)