Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 208 additions & 14 deletions crates/rustc_codegen_spirv/src/builder/spirv_asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ use rustc_ast::ast::{InlineAsmOptions, InlineAsmTemplatePiece};
use rustc_codegen_ssa::mir::operand::OperandValue;
use rustc_codegen_ssa::mir::place::PlaceRef;
use rustc_codegen_ssa::traits::{
AsmBuilderMethods, BackendTypes, BuilderMethods, InlineAsmOperandRef,
AsmBuilderMethods, AsmCodegenMethods, BackendTypes, BuilderMethods, GlobalAsmOperandRef,
InlineAsmOperandRef,
};
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_errors::{Diag, DiagMessage};
use rustc_middle::{bug, ty::Instance};
use rustc_span::{DUMMY_SP, Span};
use rustc_target::asm::{InlineAsmRegClass, InlineAsmRegOrRegClass, SpirVInlineAsmRegClass};
Expand Down Expand Up @@ -68,6 +70,153 @@ fn inline_asm_operand_ref_clone<'tcx, B: BackendTypes + ?Sized>(
}
}

impl<'tcx> AsmCodegenMethods<'tcx> for CodegenCx<'tcx> {
fn codegen_global_asm(
&mut self,
template: &[InlineAsmTemplatePiece],
operands: &[GlobalAsmOperandRef<'tcx>],
options: InlineAsmOptions,
line_spans: &[Span],
) {
const SUPPORTED_OPTIONS: InlineAsmOptions = InlineAsmOptions::empty();
let unsupported_options = options & !SUPPORTED_OPTIONS;
if !unsupported_options.is_empty() {
self.tcx.dcx().span_err(
line_spans.first().copied().unwrap_or_default(),
format!("global_asm! flags not supported: {unsupported_options:?}"),
);
}

// vec of lines, and each line is vec of tokens
let mut tokens = vec![vec![]];
for piece in template {
match piece {
InlineAsmTemplatePiece::String(asm) => {
// We cannot use str::lines() here because we don't want the behavior of "the
// last newline is optional", we want an empty string for the last line if
// there is no newline terminator.
// Lambda copied from std LinesAnyMap
let lines = asm.split('\n').map(|line| {
let l = line.len();
if l > 0 && line.as_bytes()[l - 1] == b'\r' {
&line[0..l - 1]
} else {
line
}
});
for (index, line) in lines.enumerate() {
if index != 0 {
// There was a newline, add a new line.
tokens.push(vec![]);
}
let mut chars = line.chars();

let span = line_spans
.get(tokens.len() - 1)
.copied()
.unwrap_or_default();
while let Some(token) = InlineAsmCx::Global(self, span).lex_word(&mut chars)
{
tokens.last_mut().unwrap().push(token);
}
}
}
&InlineAsmTemplatePiece::Placeholder {
operand_idx,
modifier,
span,
} => {
if let Some(modifier) = modifier {
self.tcx
.dcx()
.span_err(span, format!("asm modifiers are not supported: {modifier}"));
}
let span = line_spans
.get(tokens.len() - 1)
.copied()
.unwrap_or_default();
let line = tokens.last_mut().unwrap();
let typeof_kind = line.last().and_then(|prev| match prev {
Token::Word("typeof") => Some(TypeofKind::Plain),
Token::Word("typeof*") => Some(TypeofKind::Dereference),
_ => None,
});
let operand = &operands[operand_idx];
match typeof_kind {
Some(_) => match operand {
GlobalAsmOperandRef::Const { string: _ } => {
self.tcx
.dcx()
.span_err(span, "cannot take the type of a const asm argument");
}
GlobalAsmOperandRef::SymFn { instance: _ } => {
self.tcx.dcx().span_err(
span,
"cannot take the type of a function asm argument",
);
}
GlobalAsmOperandRef::SymStatic { def_id: _ } => {
self.tcx.dcx().span_err(
span,
"cannot take the type of a static variable asm argument",
);
}
},
None => match operand {
GlobalAsmOperandRef::Const { string } => line.push(Token::Word(string)),
GlobalAsmOperandRef::SymFn { instance: _ } => {
self.tcx
.dcx()
.span_err(span, "function asm argument not supported yet");
}
GlobalAsmOperandRef::SymStatic { def_id: _ } => {
self.tcx.dcx().span_err(
span,
"static variable asm argument not supported yet",
);
}
},
}
}
}
}

let mut id_map = FxHashMap::default();
let mut defined_ids = FxHashSet::default();
let mut id_to_type_map = FxHashMap::default();

let mut asm_block = AsmBlock::Open;
for (line_idx, line) in tokens.into_iter().enumerate() {
let span = line_spans.get(line_idx).copied().unwrap_or_default();
InlineAsmCx::Global(self, span).codegen_asm(
&mut id_map,
&mut defined_ids,
&mut id_to_type_map,
&mut asm_block,
line.into_iter(),
);
}

for (id, num) in id_map {
if !defined_ids.contains(&num) {
self.tcx.dcx().span_err(
line_spans.first().copied().unwrap_or_default(),
format!("%{id} is used but not defined"),
);
}
}
}

// FIXME(eddyb) should this method be implemented as just symbol mangling,
// or renamed upstream into something much more specific?
fn mangled_name(&self, instance: Instance<'tcx>) -> String {
self.tcx.dcx().span_bug(
self.tcx.def_span(instance.def_id()),
"[Rust-GPU] `#[naked] fn` not yet supported",
)
}
}

impl<'a, 'tcx> AsmBuilderMethods<'tcx> for Builder<'a, 'tcx> {
/* Example asm and the template it compiles to:
asm!(
Expand Down Expand Up @@ -103,7 +252,7 @@ impl<'a, 'tcx> AsmBuilderMethods<'tcx> for Builder<'a, 'tcx> {
const SUPPORTED_OPTIONS: InlineAsmOptions = InlineAsmOptions::NORETURN;
let unsupported_options = options & !SUPPORTED_OPTIONS;
if !unsupported_options.is_empty() {
self.err(format!("asm flags not supported: {unsupported_options:?}"));
self.err(format!("asm! flags not supported: {unsupported_options:?}"));
}

// HACK(eddyb) get more accurate pointers types, for pointer operands,
Expand Down Expand Up @@ -165,7 +314,7 @@ impl<'a, 'tcx> AsmBuilderMethods<'tcx> for Builder<'a, 'tcx> {
tokens.push(vec![]);
}
let mut chars = line.chars();
while let Some(token) = self.lex_word(&mut chars) {
while let Some(token) = InlineAsmCx::Local(self).lex_word(&mut chars) {
tokens.last_mut().unwrap().push(token);
}
}
Expand Down Expand Up @@ -222,7 +371,7 @@ impl<'a, 'tcx> AsmBuilderMethods<'tcx> for Builder<'a, 'tcx> {
// which is an unfortunate interaction, but perhaps avoidable?
AsmBlock::End(_) => {}
}
self.codegen_asm(
InlineAsmCx::Local(self).codegen_asm(
&mut id_map,
&mut defined_ids,
&mut id_to_type_map,
Expand Down Expand Up @@ -287,7 +436,47 @@ enum AsmBlock {
End(Op),
}

impl<'cx, 'tcx> Builder<'cx, 'tcx> {
enum InlineAsmCx<'a, 'cx, 'tcx> {
Global(&'cx CodegenCx<'tcx>, Span),
Local(&'a mut Builder<'cx, 'tcx>),
}

impl<'cx, 'tcx> std::ops::Deref for InlineAsmCx<'_, 'cx, 'tcx> {
type Target = &'cx CodegenCx<'tcx>;
fn deref(&self) -> &Self::Target {
match self {
Self::Global(cx, _) | Self::Local(Builder { cx, .. }) => cx,
}
}
}

impl InlineAsmCx<'_, '_, '_> {
fn span(&self) -> Span {
match self {
&Self::Global(_, span) => span,
Self::Local(bx) => bx.span(),
}
}

#[track_caller]
fn struct_err(&self, msg: impl Into<DiagMessage>) -> Diag<'_> {
self.tcx.dcx().struct_span_err(self.span(), msg)
}

#[track_caller]
fn err(&self, msg: impl Into<DiagMessage>) {
self.tcx.dcx().span_err(self.span(), msg);
}

fn emit(&mut self) -> std::cell::RefMut<'_, rspirv::dr::Builder> {
match self {
Self::Global(cx, _) => cx.emit_global(),
Self::Local(bx) => bx.emit(),
}
}
}

impl<'cx, 'tcx> InlineAsmCx<'_, 'cx, 'tcx> {
fn lex_word<'a>(&self, line: &mut std::str::Chars<'a>) -> Option<Token<'a, 'cx, 'tcx>> {
loop {
let start = line.as_str();
Expand Down Expand Up @@ -553,7 +742,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
};
let inst_class = inst_name
.strip_prefix("Op")
.and_then(|n| self.cx.instruction_table.table.get(n));
.and_then(|n| self.instruction_table.table.get(n));
let inst_class = if let Some(inst) = inst_class {
inst
} else {
Expand All @@ -577,13 +766,12 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
}
self.insert_inst(id_map, defined_ids, asm_block, instruction);
if let Some(OutRegister::Place(place)) = out_register {
let place = match self {
Self::Global(..) => unreachable!(),
Self::Local(bx) => place.val.llval.def(bx),
};
self.emit()
.store(
place.val.llval.def(self),
result_id.unwrap(),
None,
std::iter::empty(),
)
.store(place, result_id.unwrap(), None, std::iter::empty())
.unwrap();
}
}
Expand Down Expand Up @@ -1093,7 +1281,10 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
Token::Placeholder(hole, span) => match hole {
InlineAsmOperandRef::In { reg, value } => {
self.check_reg(span, reg);
Some(value.immediate().def(self))
match self {
Self::Global(..) => unreachable!(),
Self::Local(bx) => Some(value.immediate().def(bx)),
}
}
InlineAsmOperandRef::Out {
reg,
Expand All @@ -1113,7 +1304,10 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
out_place: _,
} => {
self.check_reg(span, reg);
Some(in_value.immediate().def(self))
match self {
Self::Global(..) => unreachable!(),
Self::Local(bx) => Some(in_value.immediate().def(bx)),
}
}
InlineAsmOperandRef::Const { string: _ } => {
self.tcx
Expand Down
30 changes: 1 addition & 29 deletions crates/rustc_codegen_spirv/src/codegen_cx/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,8 @@ use itertools::Itertools as _;
use rspirv::dr::{Module, Operand};
use rspirv::spirv::{Decoration, LinkageType, Word};
use rustc_abi::{AddressSpace, HasDataLayout, TargetDataLayout};
use rustc_ast::ast::{InlineAsmOptions, InlineAsmTemplatePiece};
use rustc_codegen_ssa::mir::debuginfo::{FunctionDebugContext, VariableKind};
use rustc_codegen_ssa::traits::{
AsmCodegenMethods, BackendTypes, DebugInfoCodegenMethods, GlobalAsmOperandRef,
MiscCodegenMethods,
};
use rustc_codegen_ssa::traits::{BackendTypes, DebugInfoCodegenMethods, MiscCodegenMethods};
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_hir::def_id::DefId;
use rustc_middle::mir;
Expand Down Expand Up @@ -934,27 +930,3 @@ impl<'tcx> DebugInfoCodegenMethods<'tcx> for CodegenCx<'tcx> {
todo!()
}
}

impl<'tcx> AsmCodegenMethods<'tcx> for CodegenCx<'tcx> {
fn codegen_global_asm(
&mut self,
_template: &[InlineAsmTemplatePiece],
_operands: &[GlobalAsmOperandRef<'tcx>],
_options: InlineAsmOptions,
line_spans: &[Span],
) {
self.tcx.dcx().span_fatal(
line_spans.first().copied().unwrap_or_default(),
"[Rust-GPU] `global_asm!` not yet supported",
);
}

// FIXME(eddyb) should this method be implemented as just symbol mangling,
// or renamed upstream into something much more specific?
fn mangled_name(&self, instance: Instance<'tcx>) -> String {
self.tcx.dcx().span_bug(
self.tcx.def_span(instance.def_id()),
"[Rust-GPU] `#[naked] fn` not yet supported",
)
}
}
Loading