@@ -739,20 +739,31 @@ impl<'tcx> CodegenCx<'tcx> {
739739 . decorate ( var_id. unwrap ( ) , Decoration :: Invariant , std:: iter:: empty ( ) ) ;
740740 }
741741 if let Some ( per_primitive_ext) = attrs. per_primitive_ext {
742- if storage_class != Ok ( StorageClass :: Output ) {
743- self . tcx . dcx ( ) . span_fatal (
744- per_primitive_ext. span ,
745- "`#[spirv(per_primitive_ext)]` is only valid on Output variables" ,
746- ) ;
747- }
748- if !( execution_model == ExecutionModel :: MeshEXT
749- || execution_model == ExecutionModel :: MeshNV )
750- {
751- self . tcx . dcx ( ) . span_fatal (
752- per_primitive_ext. span ,
753- "`#[spirv(per_primitive_ext)]` is only valid in mesh shaders" ,
754- ) ;
742+ match execution_model {
743+ ExecutionModel :: Fragment => {
744+ if storage_class != Ok ( StorageClass :: Input ) {
745+ self . tcx . dcx ( ) . span_fatal (
746+ per_primitive_ext. span ,
747+ "`#[spirv(per_primitive_ext)]` in fragment shaders is only valid on Input variables" ,
748+ ) ;
749+ }
750+ }
751+ ExecutionModel :: MeshNV | ExecutionModel :: MeshEXT => {
752+ if storage_class != Ok ( StorageClass :: Output ) {
753+ self . tcx . dcx ( ) . span_fatal (
754+ per_primitive_ext. span ,
755+ "`#[spirv(per_primitive_ext)]` in mesh shaders is only valid on Output variables" ,
756+ ) ;
757+ }
758+ }
759+ _ => {
760+ self . tcx . dcx ( ) . span_fatal (
761+ per_primitive_ext. span ,
762+ "`#[spirv(per_primitive_ext)]` is only valid in fragment or mesh shaders" ,
763+ ) ;
764+ }
755765 }
766+
756767 self . emit_global ( ) . decorate (
757768 var_id. unwrap ( ) ,
758769 Decoration :: PerPrimitiveEXT ,
0 commit comments