Skip to content

Commit 6f193db

Browse files
committed
tool: move attr parsing to mod attr
1 parent 29ba02d commit 6f193db

File tree

2 files changed

+332
-335
lines changed

2 files changed

+332
-335
lines changed

crates/rustc_codegen_spirv/src/attr.rs

Lines changed: 321 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55
use crate::codegen_cx::CodegenCx;
66
use crate::symbols::Symbols;
77
use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
8+
use rustc_ast::{LitKind, MetaItemInner, MetaItemLit};
89
use rustc_hir as hir;
910
use rustc_hir::def_id::LocalModDefId;
1011
use rustc_hir::intravisit::{self, Visitor};
1112
use rustc_hir::{Attribute, CRATE_HIR_ID, HirId, MethodKind, Target};
1213
use rustc_middle::hir::nested_filter;
1314
use rustc_middle::query::Providers;
1415
use rustc_middle::ty::TyCtxt;
15-
use rustc_span::{Span, Symbol};
16+
use rustc_span::{Ident, Span, Symbol};
1617
use std::rc::Rc;
1718

1819
// FIXME(eddyb) replace with `ArrayVec<[Word; 3]>`.
@@ -152,7 +153,7 @@ impl AggregatedSpirvAttributes {
152153

153154
// NOTE(eddyb) `span_delayed_bug` ensures that if attribute checking fails
154155
// to see an attribute error, it will cause an ICE instead.
155-
for parse_attr_result in crate::symbols::parse_attrs_for_checking(&cx.sym, attrs) {
156+
for parse_attr_result in parse_attrs_for_checking(&cx.sym, attrs) {
156157
let (span, parsed_attr) = match parse_attr_result {
157158
Ok(span_and_parsed_attr) => span_and_parsed_attr,
158159
Err((span, msg)) => {
@@ -278,7 +279,7 @@ impl CheckSpirvAttrVisitor<'_> {
278279
fn check_spirv_attributes(&self, hir_id: HirId, target: Target) {
279280
let mut aggregated_attrs = AggregatedSpirvAttributes::default();
280281

281-
let parse_attrs = |attrs| crate::symbols::parse_attrs_for_checking(&self.sym, attrs);
282+
let parse_attrs = |attrs| parse_attrs_for_checking(&self.sym, attrs);
282283

283284
let attrs = self.tcx.hir_attrs(hir_id);
284285
for parse_attr_result in parse_attrs(attrs) {
@@ -512,3 +513,320 @@ pub(crate) fn provide(providers: &mut Providers) {
512513
..*providers
513514
};
514515
}
516+
517+
// FIXME(eddyb) find something nicer for the error type.
518+
type ParseAttrError = (Span, String);
519+
520+
fn parse_attrs_for_checking<'a>(
521+
sym: &'a Symbols,
522+
attrs: &'a [Attribute],
523+
) -> impl Iterator<Item = Result<(Span, SpirvAttribute), ParseAttrError>> + 'a {
524+
attrs.iter().flat_map(move |attr| {
525+
let (whole_attr_error, args) = match attr {
526+
Attribute::Unparsed(item) => {
527+
// #[...]
528+
let s = &item.path.segments;
529+
if s.len() > 1 && s[0].name == sym.rust_gpu {
530+
// #[rust_gpu ...]
531+
if s.len() != 2 || s[1].name != sym.spirv {
532+
// #[rust_gpu::...] but not #[rust_gpu::spirv]
533+
(
534+
Some(Err((
535+
attr.span(),
536+
"unknown `rust_gpu` attribute, expected `rust_gpu::spirv`"
537+
.to_string(),
538+
))),
539+
Default::default(),
540+
)
541+
} else if let Some(args) = attr.meta_item_list() {
542+
// #[rust_gpu::spirv(...)]
543+
(None, args)
544+
} else {
545+
// #[rust_gpu::spirv]
546+
(
547+
Some(Err((
548+
attr.span(),
549+
"#[rust_gpu::spirv(..)] attribute must have at least one argument"
550+
.to_string(),
551+
))),
552+
Default::default(),
553+
)
554+
}
555+
} else {
556+
// #[...] but not #[rust_gpu ...]
557+
(None, Default::default())
558+
}
559+
}
560+
Attribute::Parsed(_) => (None, Default::default()),
561+
};
562+
563+
whole_attr_error
564+
.into_iter()
565+
.chain(args.into_iter().map(move |ref arg| {
566+
let span = arg.span();
567+
let parsed_attr = if arg.has_name(sym.descriptor_set) {
568+
SpirvAttribute::DescriptorSet(parse_attr_int_value(arg)?)
569+
} else if arg.has_name(sym.binding) {
570+
SpirvAttribute::Binding(parse_attr_int_value(arg)?)
571+
} else if arg.has_name(sym.input_attachment_index) {
572+
SpirvAttribute::InputAttachmentIndex(parse_attr_int_value(arg)?)
573+
} else if arg.has_name(sym.spec_constant) {
574+
SpirvAttribute::SpecConstant(parse_spec_constant_attr(sym, arg)?)
575+
} else {
576+
let name = match arg.ident() {
577+
Some(i) => i,
578+
None => {
579+
return Err((
580+
span,
581+
"#[spirv(..)] attribute argument must be single identifier"
582+
.to_string(),
583+
));
584+
}
585+
};
586+
sym.attributes.get(&name.name).map_or_else(
587+
|| Err((name.span, "unknown argument to spirv attribute".to_string())),
588+
|a| {
589+
Ok(match a {
590+
SpirvAttribute::Entry(entry) => SpirvAttribute::Entry(
591+
parse_entry_attrs(sym, arg, &name, entry.execution_model)?,
592+
),
593+
_ => a.clone(),
594+
})
595+
},
596+
)?
597+
};
598+
Ok((span, parsed_attr))
599+
}))
600+
})
601+
}
602+
603+
fn parse_spec_constant_attr(
604+
sym: &Symbols,
605+
arg: &MetaItemInner,
606+
) -> Result<SpecConstant, ParseAttrError> {
607+
let mut id = None;
608+
let mut default = None;
609+
610+
if let Some(attrs) = arg.meta_item_list() {
611+
for attr in attrs {
612+
if attr.has_name(sym.id) {
613+
if id.is_none() {
614+
id = Some(parse_attr_int_value(attr)?);
615+
} else {
616+
return Err((attr.span(), "`id` may only be specified once".into()));
617+
}
618+
} else if attr.has_name(sym.default) {
619+
if default.is_none() {
620+
default = Some(parse_attr_int_value(attr)?);
621+
} else {
622+
return Err((attr.span(), "`default` may only be specified once".into()));
623+
}
624+
} else {
625+
return Err((attr.span(), "expected `id = ...` or `default = ...`".into()));
626+
}
627+
}
628+
}
629+
Ok(SpecConstant {
630+
id: id.ok_or_else(|| (arg.span(), "expected `spec_constant(id = ...)`".into()))?,
631+
default,
632+
})
633+
}
634+
635+
fn parse_attr_int_value(arg: &MetaItemInner) -> Result<u32, ParseAttrError> {
636+
let arg = match arg.meta_item() {
637+
Some(arg) => arg,
638+
None => return Err((arg.span(), "attribute must have value".to_string())),
639+
};
640+
match arg.name_value_literal() {
641+
Some(&MetaItemLit {
642+
kind: LitKind::Int(x, ..),
643+
..
644+
}) if x <= u32::MAX as u128 => Ok(x.get() as u32),
645+
_ => Err((arg.span, "attribute value must be integer".to_string())),
646+
}
647+
}
648+
649+
fn parse_local_size_attr(arg: &MetaItemInner) -> Result<[u32; 3], ParseAttrError> {
650+
let arg = match arg.meta_item() {
651+
Some(arg) => arg,
652+
None => return Err((arg.span(), "attribute must have value".to_string())),
653+
};
654+
match arg.meta_item_list() {
655+
Some(tuple) if !tuple.is_empty() && tuple.len() < 4 => {
656+
let mut local_size = [1; 3];
657+
for (idx, lit) in tuple.iter().enumerate() {
658+
match lit {
659+
MetaItemInner::Lit(MetaItemLit {
660+
kind: LitKind::Int(x, ..),
661+
..
662+
}) if *x <= u32::MAX as u128 => local_size[idx] = x.get() as u32,
663+
_ => return Err((lit.span(), "must be a u32 literal".to_string())),
664+
}
665+
}
666+
Ok(local_size)
667+
}
668+
Some([]) => Err((
669+
arg.span,
670+
"#[spirv(compute(threads(x, y, z)))] must have the x dimension specified, trailing ones may be elided".to_string(),
671+
)),
672+
Some(tuple) if tuple.len() > 3 => Err((
673+
arg.span,
674+
"#[spirv(compute(threads(x, y, z)))] is three dimensional".to_string(),
675+
)),
676+
_ => Err((
677+
arg.span,
678+
"#[spirv(compute(threads(x, y, z)))] must have 1 to 3 parameters, trailing ones may be elided".to_string(),
679+
)),
680+
}
681+
}
682+
683+
// for a given entry, gather up the additional attributes
684+
// in this case ExecutionMode's, some have extra arguments
685+
// others are specified with x, y, or z components
686+
// ie #[spirv(fragment(origin_lower_left))] or #[spirv(gl_compute(local_size_x=64, local_size_y=8))]
687+
fn parse_entry_attrs(
688+
sym: &Symbols,
689+
arg: &MetaItemInner,
690+
name: &Ident,
691+
execution_model: ExecutionModel,
692+
) -> Result<Entry, ParseAttrError> {
693+
use ExecutionMode::*;
694+
use ExecutionModel::*;
695+
let mut entry = Entry::from(execution_model);
696+
let mut origin_mode: Option<ExecutionMode> = None;
697+
let mut local_size: Option<[u32; 3]> = None;
698+
let mut local_size_hint: Option<[u32; 3]> = None;
699+
// Reserved
700+
//let mut max_workgroup_size_intel: Option<[u32; 3]> = None;
701+
if let Some(attrs) = arg.meta_item_list() {
702+
for attr in attrs {
703+
if let Some(attr_name) = attr.ident() {
704+
if let Some((execution_mode, extra_dim)) = sym.execution_modes.get(&attr_name.name)
705+
{
706+
use crate::symbols::ExecutionModeExtraDim::*;
707+
let val = match extra_dim {
708+
None | Tuple => Option::None,
709+
_ => Some(parse_attr_int_value(attr)?),
710+
};
711+
match execution_mode {
712+
OriginUpperLeft | OriginLowerLeft => {
713+
origin_mode.replace(*execution_mode);
714+
}
715+
LocalSize => {
716+
if local_size.is_none() {
717+
local_size.replace(parse_local_size_attr(attr)?);
718+
} else {
719+
return Err((
720+
attr_name.span,
721+
String::from(
722+
"`#[spirv(compute(threads))]` may only be specified once",
723+
),
724+
));
725+
}
726+
}
727+
LocalSizeHint => {
728+
let val = val.unwrap();
729+
if local_size_hint.is_none() {
730+
local_size_hint.replace([1, 1, 1]);
731+
}
732+
let local_size_hint = local_size_hint.as_mut().unwrap();
733+
match extra_dim {
734+
X => {
735+
local_size_hint[0] = val;
736+
}
737+
Y => {
738+
local_size_hint[1] = val;
739+
}
740+
Z => {
741+
local_size_hint[2] = val;
742+
}
743+
_ => unreachable!(),
744+
}
745+
}
746+
// Reserved
747+
/*MaxWorkgroupSizeINTEL => {
748+
let val = val.unwrap();
749+
if max_workgroup_size_intel.is_none() {
750+
max_workgroup_size_intel.replace([1, 1, 1]);
751+
}
752+
let max_workgroup_size_intel = max_workgroup_size_intel.as_mut()
753+
.unwrap();
754+
match extra_dim {
755+
X => {
756+
max_workgroup_size_intel[0] = val;
757+
},
758+
Y => {
759+
max_workgroup_size_intel[1] = val;
760+
},
761+
Z => {
762+
max_workgroup_size_intel[2] = val;
763+
},
764+
_ => unreachable!(),
765+
}
766+
},*/
767+
_ => {
768+
if let Some(val) = val {
769+
entry
770+
.execution_modes
771+
.push((*execution_mode, ExecutionModeExtra::new([val])));
772+
} else {
773+
entry
774+
.execution_modes
775+
.push((*execution_mode, ExecutionModeExtra::new([])));
776+
}
777+
}
778+
}
779+
} else if attr_name.name == sym.entry_point_name {
780+
match attr.value_str() {
781+
Some(sym) => {
782+
entry.name = Some(sym);
783+
}
784+
None => {
785+
return Err((
786+
attr_name.span,
787+
format!(
788+
"#[spirv({name}(..))] unknown attribute argument {attr_name}"
789+
),
790+
));
791+
}
792+
}
793+
} else {
794+
return Err((
795+
attr_name.span,
796+
format!("#[spirv({name}(..))] unknown attribute argument {attr_name}",),
797+
));
798+
}
799+
} else {
800+
return Err((
801+
arg.span(),
802+
format!("#[spirv({name}(..))] attribute argument must be single identifier"),
803+
));
804+
}
805+
}
806+
}
807+
match entry.execution_model {
808+
Fragment => {
809+
let origin_mode = origin_mode.unwrap_or(OriginUpperLeft);
810+
entry
811+
.execution_modes
812+
.push((origin_mode, ExecutionModeExtra::new([])));
813+
}
814+
GLCompute | MeshNV | TaskNV | TaskEXT | MeshEXT => {
815+
if let Some(local_size) = local_size {
816+
entry
817+
.execution_modes
818+
.push((LocalSize, ExecutionModeExtra::new(local_size)));
819+
} else {
820+
return Err((
821+
arg.span(),
822+
String::from(
823+
"The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]`, `#[spirv(task_nv)]`, `#[spirv(task_ext)]` or `#[spirv(mesh_ext)]`",
824+
),
825+
));
826+
}
827+
}
828+
//TODO: Cover more defaults
829+
_ => {}
830+
}
831+
Ok(entry)
832+
}

0 commit comments

Comments
 (0)