diff --git a/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs b/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs index 8441d74c08..c9d51dec1a 100644 --- a/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs +++ b/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs @@ -18,9 +18,11 @@ use rustc_ast::ast::{InlineAsmOptions, InlineAsmTemplatePiece}; use rustc_codegen_ssa::mir::operand::OperandValue; use rustc_codegen_ssa::mir::place::PlaceRef; use rustc_codegen_ssa::traits::{ - AsmBuilderMethods, BackendTypes, BuilderMethods, InlineAsmOperandRef, + AsmBuilderMethods, AsmCodegenMethods, BackendTypes, BuilderMethods, GlobalAsmOperandRef, + InlineAsmOperandRef, }; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; +use rustc_errors::{Diag, DiagMessage}; use rustc_middle::{bug, ty::Instance}; use rustc_span::{DUMMY_SP, Span}; use rustc_target::asm::{InlineAsmRegClass, InlineAsmRegOrRegClass, SpirVInlineAsmRegClass}; @@ -68,6 +70,153 @@ fn inline_asm_operand_ref_clone<'tcx, B: BackendTypes + ?Sized>( } } +impl<'tcx> AsmCodegenMethods<'tcx> for CodegenCx<'tcx> { + fn codegen_global_asm( + &mut self, + template: &[InlineAsmTemplatePiece], + operands: &[GlobalAsmOperandRef<'tcx>], + options: InlineAsmOptions, + line_spans: &[Span], + ) { + const SUPPORTED_OPTIONS: InlineAsmOptions = InlineAsmOptions::empty(); + let unsupported_options = options & !SUPPORTED_OPTIONS; + if !unsupported_options.is_empty() { + self.tcx.dcx().span_err( + line_spans.first().copied().unwrap_or_default(), + format!("global_asm! flags not supported: {unsupported_options:?}"), + ); + } + + // vec of lines, and each line is vec of tokens + let mut tokens = vec![vec![]]; + for piece in template { + match piece { + InlineAsmTemplatePiece::String(asm) => { + // We cannot use str::lines() here because we don't want the behavior of "the + // last newline is optional", we want an empty string for the last line if + // there is no newline terminator. + // Lambda copied from std LinesAnyMap + let lines = asm.split('\n').map(|line| { + let l = line.len(); + if l > 0 && line.as_bytes()[l - 1] == b'\r' { + &line[0..l - 1] + } else { + line + } + }); + for (index, line) in lines.enumerate() { + if index != 0 { + // There was a newline, add a new line. + tokens.push(vec![]); + } + let mut chars = line.chars(); + + let span = line_spans + .get(tokens.len() - 1) + .copied() + .unwrap_or_default(); + while let Some(token) = InlineAsmCx::Global(self, span).lex_word(&mut chars) + { + tokens.last_mut().unwrap().push(token); + } + } + } + &InlineAsmTemplatePiece::Placeholder { + operand_idx, + modifier, + span, + } => { + if let Some(modifier) = modifier { + self.tcx + .dcx() + .span_err(span, format!("asm modifiers are not supported: {modifier}")); + } + let span = line_spans + .get(tokens.len() - 1) + .copied() + .unwrap_or_default(); + let line = tokens.last_mut().unwrap(); + let typeof_kind = line.last().and_then(|prev| match prev { + Token::Word("typeof") => Some(TypeofKind::Plain), + Token::Word("typeof*") => Some(TypeofKind::Dereference), + _ => None, + }); + let operand = &operands[operand_idx]; + match typeof_kind { + Some(_) => match operand { + GlobalAsmOperandRef::Const { string: _ } => { + self.tcx + .dcx() + .span_err(span, "cannot take the type of a const asm argument"); + } + GlobalAsmOperandRef::SymFn { instance: _ } => { + self.tcx.dcx().span_err( + span, + "cannot take the type of a function asm argument", + ); + } + GlobalAsmOperandRef::SymStatic { def_id: _ } => { + self.tcx.dcx().span_err( + span, + "cannot take the type of a static variable asm argument", + ); + } + }, + None => match operand { + GlobalAsmOperandRef::Const { string } => line.push(Token::Word(string)), + GlobalAsmOperandRef::SymFn { instance: _ } => { + self.tcx + .dcx() + .span_err(span, "function asm argument not supported yet"); + } + GlobalAsmOperandRef::SymStatic { def_id: _ } => { + self.tcx.dcx().span_err( + span, + "static variable asm argument not supported yet", + ); + } + }, + } + } + } + } + + let mut id_map = FxHashMap::default(); + let mut defined_ids = FxHashSet::default(); + let mut id_to_type_map = FxHashMap::default(); + + let mut asm_block = AsmBlock::Open; + for (line_idx, line) in tokens.into_iter().enumerate() { + let span = line_spans.get(line_idx).copied().unwrap_or_default(); + InlineAsmCx::Global(self, span).codegen_asm( + &mut id_map, + &mut defined_ids, + &mut id_to_type_map, + &mut asm_block, + line.into_iter(), + ); + } + + for (id, num) in id_map { + if !defined_ids.contains(&num) { + self.tcx.dcx().span_err( + line_spans.first().copied().unwrap_or_default(), + format!("%{id} is used but not defined"), + ); + } + } + } + + // FIXME(eddyb) should this method be implemented as just symbol mangling, + // or renamed upstream into something much more specific? + fn mangled_name(&self, instance: Instance<'tcx>) -> String { + self.tcx.dcx().span_bug( + self.tcx.def_span(instance.def_id()), + "[Rust-GPU] `#[naked] fn` not yet supported", + ) + } +} + impl<'a, 'tcx> AsmBuilderMethods<'tcx> for Builder<'a, 'tcx> { /* Example asm and the template it compiles to: asm!( @@ -103,7 +252,7 @@ impl<'a, 'tcx> AsmBuilderMethods<'tcx> for Builder<'a, 'tcx> { const SUPPORTED_OPTIONS: InlineAsmOptions = InlineAsmOptions::NORETURN; let unsupported_options = options & !SUPPORTED_OPTIONS; if !unsupported_options.is_empty() { - self.err(format!("asm flags not supported: {unsupported_options:?}")); + self.err(format!("asm! flags not supported: {unsupported_options:?}")); } // HACK(eddyb) get more accurate pointers types, for pointer operands, @@ -165,7 +314,7 @@ impl<'a, 'tcx> AsmBuilderMethods<'tcx> for Builder<'a, 'tcx> { tokens.push(vec![]); } let mut chars = line.chars(); - while let Some(token) = self.lex_word(&mut chars) { + while let Some(token) = InlineAsmCx::Local(self).lex_word(&mut chars) { tokens.last_mut().unwrap().push(token); } } @@ -222,7 +371,7 @@ impl<'a, 'tcx> AsmBuilderMethods<'tcx> for Builder<'a, 'tcx> { // which is an unfortunate interaction, but perhaps avoidable? AsmBlock::End(_) => {} } - self.codegen_asm( + InlineAsmCx::Local(self).codegen_asm( &mut id_map, &mut defined_ids, &mut id_to_type_map, @@ -287,7 +436,47 @@ enum AsmBlock { End(Op), } -impl<'cx, 'tcx> Builder<'cx, 'tcx> { +enum InlineAsmCx<'a, 'cx, 'tcx> { + Global(&'cx CodegenCx<'tcx>, Span), + Local(&'a mut Builder<'cx, 'tcx>), +} + +impl<'cx, 'tcx> std::ops::Deref for InlineAsmCx<'_, 'cx, 'tcx> { + type Target = &'cx CodegenCx<'tcx>; + fn deref(&self) -> &Self::Target { + match self { + Self::Global(cx, _) | Self::Local(Builder { cx, .. }) => cx, + } + } +} + +impl InlineAsmCx<'_, '_, '_> { + fn span(&self) -> Span { + match self { + &Self::Global(_, span) => span, + Self::Local(bx) => bx.span(), + } + } + + #[track_caller] + fn struct_err(&self, msg: impl Into) -> Diag<'_> { + self.tcx.dcx().struct_span_err(self.span(), msg) + } + + #[track_caller] + fn err(&self, msg: impl Into) { + self.tcx.dcx().span_err(self.span(), msg); + } + + fn emit(&mut self) -> std::cell::RefMut<'_, rspirv::dr::Builder> { + match self { + Self::Global(cx, _) => cx.emit_global(), + Self::Local(bx) => bx.emit(), + } + } +} + +impl<'cx, 'tcx> InlineAsmCx<'_, 'cx, 'tcx> { fn lex_word<'a>(&self, line: &mut std::str::Chars<'a>) -> Option> { loop { let start = line.as_str(); @@ -553,7 +742,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { }; let inst_class = inst_name .strip_prefix("Op") - .and_then(|n| self.cx.instruction_table.table.get(n)); + .and_then(|n| self.instruction_table.table.get(n)); let inst_class = if let Some(inst) = inst_class { inst } else { @@ -577,13 +766,12 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { } self.insert_inst(id_map, defined_ids, asm_block, instruction); if let Some(OutRegister::Place(place)) = out_register { + let place = match self { + Self::Global(..) => unreachable!(), + Self::Local(bx) => place.val.llval.def(bx), + }; self.emit() - .store( - place.val.llval.def(self), - result_id.unwrap(), - None, - std::iter::empty(), - ) + .store(place, result_id.unwrap(), None, std::iter::empty()) .unwrap(); } } @@ -1093,7 +1281,10 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { Token::Placeholder(hole, span) => match hole { InlineAsmOperandRef::In { reg, value } => { self.check_reg(span, reg); - Some(value.immediate().def(self)) + match self { + Self::Global(..) => unreachable!(), + Self::Local(bx) => Some(value.immediate().def(bx)), + } } InlineAsmOperandRef::Out { reg, @@ -1113,7 +1304,10 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { out_place: _, } => { self.check_reg(span, reg); - Some(in_value.immediate().def(self)) + match self { + Self::Global(..) => unreachable!(), + Self::Local(bx) => Some(in_value.immediate().def(bx)), + } } InlineAsmOperandRef::Const { string: _ } => { self.tcx diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs index f3811f8a19..efa85db951 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs @@ -19,12 +19,8 @@ use itertools::Itertools as _; use rspirv::dr::{Module, Operand}; use rspirv::spirv::{Decoration, LinkageType, Word}; use rustc_abi::{AddressSpace, HasDataLayout, TargetDataLayout}; -use rustc_ast::ast::{InlineAsmOptions, InlineAsmTemplatePiece}; use rustc_codegen_ssa::mir::debuginfo::{FunctionDebugContext, VariableKind}; -use rustc_codegen_ssa::traits::{ - AsmCodegenMethods, BackendTypes, DebugInfoCodegenMethods, GlobalAsmOperandRef, - MiscCodegenMethods, -}; +use rustc_codegen_ssa::traits::{BackendTypes, DebugInfoCodegenMethods, MiscCodegenMethods}; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use rustc_hir::def_id::DefId; use rustc_middle::mir; @@ -934,27 +930,3 @@ impl<'tcx> DebugInfoCodegenMethods<'tcx> for CodegenCx<'tcx> { todo!() } } - -impl<'tcx> AsmCodegenMethods<'tcx> for CodegenCx<'tcx> { - fn codegen_global_asm( - &mut self, - _template: &[InlineAsmTemplatePiece], - _operands: &[GlobalAsmOperandRef<'tcx>], - _options: InlineAsmOptions, - line_spans: &[Span], - ) { - self.tcx.dcx().span_fatal( - line_spans.first().copied().unwrap_or_default(), - "[Rust-GPU] `global_asm!` not yet supported", - ); - } - - // FIXME(eddyb) should this method be implemented as just symbol mangling, - // or renamed upstream into something much more specific? - fn mangled_name(&self, instance: Instance<'tcx>) -> String { - self.tcx.dcx().span_bug( - self.tcx.def_span(instance.def_id()), - "[Rust-GPU] `#[naked] fn` not yet supported", - ) - } -}