|
5 | 5 | use crate::codegen_cx::CodegenCx;
|
6 | 6 | use crate::symbols::Symbols;
|
7 | 7 | use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
|
| 8 | +use rustc_ast::{LitKind, MetaItemInner, MetaItemLit}; |
8 | 9 | use rustc_hir as hir;
|
9 | 10 | use rustc_hir::def_id::LocalModDefId;
|
10 | 11 | use rustc_hir::intravisit::{self, Visitor};
|
11 | 12 | use rustc_hir::{Attribute, CRATE_HIR_ID, HirId, MethodKind, Target};
|
12 | 13 | use rustc_middle::hir::nested_filter;
|
13 | 14 | use rustc_middle::query::Providers;
|
14 | 15 | use rustc_middle::ty::TyCtxt;
|
15 |
| -use rustc_span::{Span, Symbol}; |
| 16 | +use rustc_span::{Ident, Span, Symbol}; |
16 | 17 | use std::rc::Rc;
|
17 | 18 |
|
18 | 19 | // FIXME(eddyb) replace with `ArrayVec<[Word; 3]>`.
|
@@ -152,7 +153,7 @@ impl AggregatedSpirvAttributes {
|
152 | 153 |
|
153 | 154 | // NOTE(eddyb) `span_delayed_bug` ensures that if attribute checking fails
|
154 | 155 | // to see an attribute error, it will cause an ICE instead.
|
155 |
| - for parse_attr_result in crate::symbols::parse_attrs_for_checking(&cx.sym, attrs) { |
| 156 | + for parse_attr_result in parse_attrs_for_checking(&cx.sym, attrs) { |
156 | 157 | let (span, parsed_attr) = match parse_attr_result {
|
157 | 158 | Ok(span_and_parsed_attr) => span_and_parsed_attr,
|
158 | 159 | Err((span, msg)) => {
|
@@ -278,7 +279,7 @@ impl CheckSpirvAttrVisitor<'_> {
|
278 | 279 | fn check_spirv_attributes(&self, hir_id: HirId, target: Target) {
|
279 | 280 | let mut aggregated_attrs = AggregatedSpirvAttributes::default();
|
280 | 281 |
|
281 |
| - let parse_attrs = |attrs| crate::symbols::parse_attrs_for_checking(&self.sym, attrs); |
| 282 | + let parse_attrs = |attrs| parse_attrs_for_checking(&self.sym, attrs); |
282 | 283 |
|
283 | 284 | let attrs = self.tcx.hir_attrs(hir_id);
|
284 | 285 | for parse_attr_result in parse_attrs(attrs) {
|
@@ -512,3 +513,320 @@ pub(crate) fn provide(providers: &mut Providers) {
|
512 | 513 | ..*providers
|
513 | 514 | };
|
514 | 515 | }
|
| 516 | + |
| 517 | +// FIXME(eddyb) find something nicer for the error type. |
| 518 | +type ParseAttrError = (Span, String); |
| 519 | + |
| 520 | +fn parse_attrs_for_checking<'a>( |
| 521 | + sym: &'a Symbols, |
| 522 | + attrs: &'a [Attribute], |
| 523 | +) -> impl Iterator<Item = Result<(Span, SpirvAttribute), ParseAttrError>> + 'a { |
| 524 | + attrs.iter().flat_map(move |attr| { |
| 525 | + let (whole_attr_error, args) = match attr { |
| 526 | + Attribute::Unparsed(item) => { |
| 527 | + // #[...] |
| 528 | + let s = &item.path.segments; |
| 529 | + if s.len() > 1 && s[0].name == sym.rust_gpu { |
| 530 | + // #[rust_gpu ...] |
| 531 | + if s.len() != 2 || s[1].name != sym.spirv { |
| 532 | + // #[rust_gpu::...] but not #[rust_gpu::spirv] |
| 533 | + ( |
| 534 | + Some(Err(( |
| 535 | + attr.span(), |
| 536 | + "unknown `rust_gpu` attribute, expected `rust_gpu::spirv`" |
| 537 | + .to_string(), |
| 538 | + ))), |
| 539 | + Default::default(), |
| 540 | + ) |
| 541 | + } else if let Some(args) = attr.meta_item_list() { |
| 542 | + // #[rust_gpu::spirv(...)] |
| 543 | + (None, args) |
| 544 | + } else { |
| 545 | + // #[rust_gpu::spirv] |
| 546 | + ( |
| 547 | + Some(Err(( |
| 548 | + attr.span(), |
| 549 | + "#[rust_gpu::spirv(..)] attribute must have at least one argument" |
| 550 | + .to_string(), |
| 551 | + ))), |
| 552 | + Default::default(), |
| 553 | + ) |
| 554 | + } |
| 555 | + } else { |
| 556 | + // #[...] but not #[rust_gpu ...] |
| 557 | + (None, Default::default()) |
| 558 | + } |
| 559 | + } |
| 560 | + Attribute::Parsed(_) => (None, Default::default()), |
| 561 | + }; |
| 562 | + |
| 563 | + whole_attr_error |
| 564 | + .into_iter() |
| 565 | + .chain(args.into_iter().map(move |ref arg| { |
| 566 | + let span = arg.span(); |
| 567 | + let parsed_attr = if arg.has_name(sym.descriptor_set) { |
| 568 | + SpirvAttribute::DescriptorSet(parse_attr_int_value(arg)?) |
| 569 | + } else if arg.has_name(sym.binding) { |
| 570 | + SpirvAttribute::Binding(parse_attr_int_value(arg)?) |
| 571 | + } else if arg.has_name(sym.input_attachment_index) { |
| 572 | + SpirvAttribute::InputAttachmentIndex(parse_attr_int_value(arg)?) |
| 573 | + } else if arg.has_name(sym.spec_constant) { |
| 574 | + SpirvAttribute::SpecConstant(parse_spec_constant_attr(sym, arg)?) |
| 575 | + } else { |
| 576 | + let name = match arg.ident() { |
| 577 | + Some(i) => i, |
| 578 | + None => { |
| 579 | + return Err(( |
| 580 | + span, |
| 581 | + "#[spirv(..)] attribute argument must be single identifier" |
| 582 | + .to_string(), |
| 583 | + )); |
| 584 | + } |
| 585 | + }; |
| 586 | + sym.attributes.get(&name.name).map_or_else( |
| 587 | + || Err((name.span, "unknown argument to spirv attribute".to_string())), |
| 588 | + |a| { |
| 589 | + Ok(match a { |
| 590 | + SpirvAttribute::Entry(entry) => SpirvAttribute::Entry( |
| 591 | + parse_entry_attrs(sym, arg, &name, entry.execution_model)?, |
| 592 | + ), |
| 593 | + _ => a.clone(), |
| 594 | + }) |
| 595 | + }, |
| 596 | + )? |
| 597 | + }; |
| 598 | + Ok((span, parsed_attr)) |
| 599 | + })) |
| 600 | + }) |
| 601 | +} |
| 602 | + |
| 603 | +fn parse_spec_constant_attr( |
| 604 | + sym: &Symbols, |
| 605 | + arg: &MetaItemInner, |
| 606 | +) -> Result<SpecConstant, ParseAttrError> { |
| 607 | + let mut id = None; |
| 608 | + let mut default = None; |
| 609 | + |
| 610 | + if let Some(attrs) = arg.meta_item_list() { |
| 611 | + for attr in attrs { |
| 612 | + if attr.has_name(sym.id) { |
| 613 | + if id.is_none() { |
| 614 | + id = Some(parse_attr_int_value(attr)?); |
| 615 | + } else { |
| 616 | + return Err((attr.span(), "`id` may only be specified once".into())); |
| 617 | + } |
| 618 | + } else if attr.has_name(sym.default) { |
| 619 | + if default.is_none() { |
| 620 | + default = Some(parse_attr_int_value(attr)?); |
| 621 | + } else { |
| 622 | + return Err((attr.span(), "`default` may only be specified once".into())); |
| 623 | + } |
| 624 | + } else { |
| 625 | + return Err((attr.span(), "expected `id = ...` or `default = ...`".into())); |
| 626 | + } |
| 627 | + } |
| 628 | + } |
| 629 | + Ok(SpecConstant { |
| 630 | + id: id.ok_or_else(|| (arg.span(), "expected `spec_constant(id = ...)`".into()))?, |
| 631 | + default, |
| 632 | + }) |
| 633 | +} |
| 634 | + |
| 635 | +fn parse_attr_int_value(arg: &MetaItemInner) -> Result<u32, ParseAttrError> { |
| 636 | + let arg = match arg.meta_item() { |
| 637 | + Some(arg) => arg, |
| 638 | + None => return Err((arg.span(), "attribute must have value".to_string())), |
| 639 | + }; |
| 640 | + match arg.name_value_literal() { |
| 641 | + Some(&MetaItemLit { |
| 642 | + kind: LitKind::Int(x, ..), |
| 643 | + .. |
| 644 | + }) if x <= u32::MAX as u128 => Ok(x.get() as u32), |
| 645 | + _ => Err((arg.span, "attribute value must be integer".to_string())), |
| 646 | + } |
| 647 | +} |
| 648 | + |
| 649 | +fn parse_local_size_attr(arg: &MetaItemInner) -> Result<[u32; 3], ParseAttrError> { |
| 650 | + let arg = match arg.meta_item() { |
| 651 | + Some(arg) => arg, |
| 652 | + None => return Err((arg.span(), "attribute must have value".to_string())), |
| 653 | + }; |
| 654 | + match arg.meta_item_list() { |
| 655 | + Some(tuple) if !tuple.is_empty() && tuple.len() < 4 => { |
| 656 | + let mut local_size = [1; 3]; |
| 657 | + for (idx, lit) in tuple.iter().enumerate() { |
| 658 | + match lit { |
| 659 | + MetaItemInner::Lit(MetaItemLit { |
| 660 | + kind: LitKind::Int(x, ..), |
| 661 | + .. |
| 662 | + }) if *x <= u32::MAX as u128 => local_size[idx] = x.get() as u32, |
| 663 | + _ => return Err((lit.span(), "must be a u32 literal".to_string())), |
| 664 | + } |
| 665 | + } |
| 666 | + Ok(local_size) |
| 667 | + } |
| 668 | + Some([]) => Err(( |
| 669 | + arg.span, |
| 670 | + "#[spirv(compute(threads(x, y, z)))] must have the x dimension specified, trailing ones may be elided".to_string(), |
| 671 | + )), |
| 672 | + Some(tuple) if tuple.len() > 3 => Err(( |
| 673 | + arg.span, |
| 674 | + "#[spirv(compute(threads(x, y, z)))] is three dimensional".to_string(), |
| 675 | + )), |
| 676 | + _ => Err(( |
| 677 | + arg.span, |
| 678 | + "#[spirv(compute(threads(x, y, z)))] must have 1 to 3 parameters, trailing ones may be elided".to_string(), |
| 679 | + )), |
| 680 | + } |
| 681 | +} |
| 682 | + |
| 683 | +// for a given entry, gather up the additional attributes |
| 684 | +// in this case ExecutionMode's, some have extra arguments |
| 685 | +// others are specified with x, y, or z components |
| 686 | +// ie #[spirv(fragment(origin_lower_left))] or #[spirv(gl_compute(local_size_x=64, local_size_y=8))] |
| 687 | +fn parse_entry_attrs( |
| 688 | + sym: &Symbols, |
| 689 | + arg: &MetaItemInner, |
| 690 | + name: &Ident, |
| 691 | + execution_model: ExecutionModel, |
| 692 | +) -> Result<Entry, ParseAttrError> { |
| 693 | + use ExecutionMode::*; |
| 694 | + use ExecutionModel::*; |
| 695 | + let mut entry = Entry::from(execution_model); |
| 696 | + let mut origin_mode: Option<ExecutionMode> = None; |
| 697 | + let mut local_size: Option<[u32; 3]> = None; |
| 698 | + let mut local_size_hint: Option<[u32; 3]> = None; |
| 699 | + // Reserved |
| 700 | + //let mut max_workgroup_size_intel: Option<[u32; 3]> = None; |
| 701 | + if let Some(attrs) = arg.meta_item_list() { |
| 702 | + for attr in attrs { |
| 703 | + if let Some(attr_name) = attr.ident() { |
| 704 | + if let Some((execution_mode, extra_dim)) = sym.execution_modes.get(&attr_name.name) |
| 705 | + { |
| 706 | + use crate::symbols::ExecutionModeExtraDim::*; |
| 707 | + let val = match extra_dim { |
| 708 | + None | Tuple => Option::None, |
| 709 | + _ => Some(parse_attr_int_value(attr)?), |
| 710 | + }; |
| 711 | + match execution_mode { |
| 712 | + OriginUpperLeft | OriginLowerLeft => { |
| 713 | + origin_mode.replace(*execution_mode); |
| 714 | + } |
| 715 | + LocalSize => { |
| 716 | + if local_size.is_none() { |
| 717 | + local_size.replace(parse_local_size_attr(attr)?); |
| 718 | + } else { |
| 719 | + return Err(( |
| 720 | + attr_name.span, |
| 721 | + String::from( |
| 722 | + "`#[spirv(compute(threads))]` may only be specified once", |
| 723 | + ), |
| 724 | + )); |
| 725 | + } |
| 726 | + } |
| 727 | + LocalSizeHint => { |
| 728 | + let val = val.unwrap(); |
| 729 | + if local_size_hint.is_none() { |
| 730 | + local_size_hint.replace([1, 1, 1]); |
| 731 | + } |
| 732 | + let local_size_hint = local_size_hint.as_mut().unwrap(); |
| 733 | + match extra_dim { |
| 734 | + X => { |
| 735 | + local_size_hint[0] = val; |
| 736 | + } |
| 737 | + Y => { |
| 738 | + local_size_hint[1] = val; |
| 739 | + } |
| 740 | + Z => { |
| 741 | + local_size_hint[2] = val; |
| 742 | + } |
| 743 | + _ => unreachable!(), |
| 744 | + } |
| 745 | + } |
| 746 | + // Reserved |
| 747 | + /*MaxWorkgroupSizeINTEL => { |
| 748 | + let val = val.unwrap(); |
| 749 | + if max_workgroup_size_intel.is_none() { |
| 750 | + max_workgroup_size_intel.replace([1, 1, 1]); |
| 751 | + } |
| 752 | + let max_workgroup_size_intel = max_workgroup_size_intel.as_mut() |
| 753 | + .unwrap(); |
| 754 | + match extra_dim { |
| 755 | + X => { |
| 756 | + max_workgroup_size_intel[0] = val; |
| 757 | + }, |
| 758 | + Y => { |
| 759 | + max_workgroup_size_intel[1] = val; |
| 760 | + }, |
| 761 | + Z => { |
| 762 | + max_workgroup_size_intel[2] = val; |
| 763 | + }, |
| 764 | + _ => unreachable!(), |
| 765 | + } |
| 766 | + },*/ |
| 767 | + _ => { |
| 768 | + if let Some(val) = val { |
| 769 | + entry |
| 770 | + .execution_modes |
| 771 | + .push((*execution_mode, ExecutionModeExtra::new([val]))); |
| 772 | + } else { |
| 773 | + entry |
| 774 | + .execution_modes |
| 775 | + .push((*execution_mode, ExecutionModeExtra::new([]))); |
| 776 | + } |
| 777 | + } |
| 778 | + } |
| 779 | + } else if attr_name.name == sym.entry_point_name { |
| 780 | + match attr.value_str() { |
| 781 | + Some(sym) => { |
| 782 | + entry.name = Some(sym); |
| 783 | + } |
| 784 | + None => { |
| 785 | + return Err(( |
| 786 | + attr_name.span, |
| 787 | + format!( |
| 788 | + "#[spirv({name}(..))] unknown attribute argument {attr_name}" |
| 789 | + ), |
| 790 | + )); |
| 791 | + } |
| 792 | + } |
| 793 | + } else { |
| 794 | + return Err(( |
| 795 | + attr_name.span, |
| 796 | + format!("#[spirv({name}(..))] unknown attribute argument {attr_name}",), |
| 797 | + )); |
| 798 | + } |
| 799 | + } else { |
| 800 | + return Err(( |
| 801 | + arg.span(), |
| 802 | + format!("#[spirv({name}(..))] attribute argument must be single identifier"), |
| 803 | + )); |
| 804 | + } |
| 805 | + } |
| 806 | + } |
| 807 | + match entry.execution_model { |
| 808 | + Fragment => { |
| 809 | + let origin_mode = origin_mode.unwrap_or(OriginUpperLeft); |
| 810 | + entry |
| 811 | + .execution_modes |
| 812 | + .push((origin_mode, ExecutionModeExtra::new([]))); |
| 813 | + } |
| 814 | + GLCompute | MeshNV | TaskNV | TaskEXT | MeshEXT => { |
| 815 | + if let Some(local_size) = local_size { |
| 816 | + entry |
| 817 | + .execution_modes |
| 818 | + .push((LocalSize, ExecutionModeExtra::new(local_size))); |
| 819 | + } else { |
| 820 | + return Err(( |
| 821 | + arg.span(), |
| 822 | + String::from( |
| 823 | + "The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]`, `#[spirv(task_nv)]`, `#[spirv(task_ext)]` or `#[spirv(mesh_ext)]`", |
| 824 | + ), |
| 825 | + )); |
| 826 | + } |
| 827 | + } |
| 828 | + //TODO: Cover more defaults |
| 829 | + _ => {} |
| 830 | + } |
| 831 | + Ok(entry) |
| 832 | +} |
0 commit comments